Compare commits

..

1 Commits

Author SHA1 Message Date
Nader Arbabian 4ac51947b4 fix pyright errors + revert to old way of handling cancelled api requests 2025-07-17 15:29:41 -07:00
12 changed files with 83 additions and 132 deletions
+43 -45
View File
@@ -126,7 +126,7 @@ class Backend:
async def cancel_api_call_if_disconnected() -> web.Response: async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection() await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload) self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500) return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
@@ -141,6 +141,7 @@ class Backend:
else: else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}") log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try: try:
start_time = time.time()
response = await self.__call_api(handler=handler, payload=payload) response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status status_code = response.status
log.debug( log.debug(
@@ -152,17 +153,19 @@ class Backend:
) )
) )
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_success(workload=workload) self.metrics._request_end(
workload=workload,
req_response_time=time.time() - start_time,
reqnum=auth_data.reqnum,
)
return res return res
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(workload=workload) self.metrics._request_errored(
workload=workload, reqnum=auth_data.reqnum
)
return web.Response(status=500) return web.Response(status=500)
finally: finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release() self.sem.release()
########### ###########
@@ -183,6 +186,12 @@ class Backend:
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally:
if request.task.cancelled():
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(
workload=workload, reqnum=auth_data.reqnum
)
async def __healthcheck(self): async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint health_check_url = self.benchmark_handler.healthcheck_endpoint
@@ -280,52 +289,41 @@ class Backend:
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
log.debug("Initial run to trigger model loading...")
payload = self.benchmark_handler.make_benchmark_payload()
await self.__call_api(handler=self.benchmark_handler, payload=payload)
max_throughput = 0 max_throughput = 0
last_throughput = 0
sum_throughput = 0 sum_throughput = 0
concurrent_requests = 10 if self.allow_parallel_requests else 1 for run in range(self.benchmark_handler.benchmark_runs + 1):
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time() start = time.time()
tasks = [] payload = self.benchmark_handler.make_benchmark_payload()
total_workload = 0 res = await self.__call_api(
handler=self.benchmark_handler, payload=payload
for _ in range(concurrent_requests):
payload = self.benchmark_handler.make_benchmark_payload()
total_workload += payload.count_workload()
tasks.append(
self.__call_api(handler=self.benchmark_handler, payload=payload)
)
responses = await gather(*tasks)
time_elapsed = time.time() - start
throughput = total_workload / time_elapsed
sum_throughput += throughput
max_throughput = max(max_throughput, throughput)
# Log results for debugging
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {len([r for r in responses if r.status == 200])}",
"#" * 60,
]
)
) )
data = await res.json()
time_elapsed = time.time() - start
# first run triggers one-time loading of the model which is very slow, so we skip counting it
if run == 0:
continue
else:
workload = payload.count_workload()
last_throughput = workload / time_elapsed
sum_throughput += last_throughput
max_throughput = max(max_throughput, last_throughput)
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
"",
f"response: {data}",
"#" * 60,
]
)
)
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
log.debug( log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}" f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
) )
# save max_throughput so we don't have to run benchmark again on restart of cold instances
with open(BENCHMARK_INDICATOR_FILE, "w") as f: with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput)) f.write(str(max_throughput))
return max_throughput return max_throughput
+4 -7
View File
@@ -8,6 +8,7 @@ from aiohttp import web, ClientResponse
import inspect import inspect
import psutil import psutil
import requests
""" """
@@ -205,13 +206,13 @@ class ModelMetrics:
workload_received: float workload_received: float
workload_cancelled: float workload_cancelled: float
workload_errored: float workload_errored: float
# these are not
workload_pending: float workload_pending: float
# these are not
cur_perf: float
error_msg: Optional[str] error_msg: Optional[str]
max_throughput: float max_throughput: float
requests_recieved: Set[int] = field(default_factory=set) requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set) requests_working: Set[int] = field(default_factory=set)
last_update: float = field(default_factory=time.time)
@classmethod @classmethod
def empty(cls): def empty(cls):
@@ -220,15 +221,12 @@ class ModelMetrics:
workload_served=0.0, workload_served=0.0,
workload_cancelled=0.0, workload_cancelled=0.0,
workload_errored=0.0, workload_errored=0.0,
cur_perf=0.0,
workload_received=0.0, workload_received=0.0,
error_msg=None, error_msg=None,
max_throughput=0.0, max_throughput=0.0,
) )
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property @property
def workload_processing(self) -> float: def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0) return max(self.workload_received - self.workload_cancelled, 0.0)
@@ -242,7 +240,6 @@ class ModelMetrics:
self.workload_received = 0 self.workload_received = 0
self.workload_cancelled = 0 self.workload_cancelled = 0
self.workload_errored = 0 self.workload_errored = 0
self.last_update = time.time()
@dataclass @dataclass
+13 -11
View File
@@ -46,31 +46,33 @@ class Metrics:
self.model_metrics.requests_recieved.add(reqnum) self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum) self.model_metrics.requests_working.add(reqnum)
def _request_end(self, workload: float, reqnum: int) -> None: def _request_end(
self, workload: float, req_response_time: float, reqnum: int
) -> None:
""" """
this function is called after handling of a request ends, regardless of the outcome this function is called after a response from model API is received.
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
def _request_success(self, workload: float) -> None:
"""
this function is called after a response from model API is received and forwarded.
""" """
self.model_metrics.workload_served += workload self.model_metrics.workload_served += workload
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.cur_perf = workload / req_response_time
self.update_pending = True self.update_pending = True
def _request_errored(self, workload: float) -> None: def _request_errored(self, workload: float, reqnum: int) -> None:
""" """
this function is called if model API returns an error this function is called if model API returns an error
""" """
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_errored += workload self.model_metrics.workload_errored += workload
self.model_metrics.requests_working.discard(reqnum)
def _request_canceled(self, workload: float) -> None: def _request_canceled(self, workload: float, reqnum: int) -> None:
""" """
this function is called if client drops connection before model API has responded this function is called if client drops connection before model API has responded
""" """
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_cancelled += workload self.model_metrics.workload_cancelled += workload
self.model_metrics.requests_working.discard(reqnum)
async def _send_metrics_loop(self) -> Awaitable[NoReturn]: async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True: while True:
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app) runner = web.AppRunner(app, handler_cancellation=True)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
-3
View File
@@ -10,7 +10,6 @@ from collections import Counter
from dataclasses import dataclass, field, asdict from dataclasses import dataclass, field, asdict
from urllib.parse import urljoin from urllib.parse import urljoin
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import requests import requests
from lib.data_types import AuthData, ApiPayload from lib.data_types import AuthData, ApiPayload
@@ -121,11 +120,9 @@ class ClientState:
self.url = worker_address self.url = worker_address
url = urljoin(worker_address, self.worker_endpoint) url = urljoin(worker_address, self.worker_endpoint)
self.status = ClientStatus.Generating self.status = ClientStatus.Generating
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
if response.status_code != 200: if response.status_code != 200:
self.infer_error.append( self.infer_error.append(
+2 -2
View File
@@ -1,4 +1,4 @@
aiohttp[speedups]==3.10.1 aiohttp==3.10.1
anyio~=4.4 anyio~=4.4
lib~=4.0 lib~=4.0
nltk~=3.9 nltk~=3.9
@@ -6,5 +6,5 @@ psutil~=6.0
pycryptodome~=3.20 pycryptodome~=3.20
Requests~=2.32 Requests~=2.32
transformers~=4.52 transformers~=4.52
utils==1.0.* utils~=1.0
hf_transfer>=0.1.9 hf_transfer>=0.1.9
+1 -6
View File
@@ -30,12 +30,7 @@ class Endpoint:
Returns: Returns:
Endpoint API key if successful, None otherwise Endpoint API key if successful, None otherwise
""" """
endpoints = { vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
"alpha": "alpha",
"candidate": "candidate",
"prod": "console",
}
vast_console_url = f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
headers = {"Authorization": f"Bearer {account_api_key}"} headers = {"Authorization": f"Bearer {account_api_key}"}
try: try:
-15
View File
@@ -1,15 +0,0 @@
import tempfile
from functools import cache
import requests
@cache
def get_cert_file_path():
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
response = requests.get(cert_url)
response.raise_for_status()
# Use a temporary file that is not deleted on close
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
f.write(response.content)
return f.name
-3
View File
@@ -5,7 +5,6 @@ import requests
from lib.test_utils import print_truncate_res from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
""" """
NOTE: this client example uses a custom comfy workflow compatible with SD3 only NOTE: this client example uses a custom comfy workflow compatible with SD3 only
@@ -52,7 +51,6 @@ def call_default_workflow(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
@@ -143,7 +141,6 @@ def call_custom_workflow_for_sd3(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
+17 -26
View File
@@ -6,7 +6,6 @@ from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List from typing import Dict, Any, Optional, Iterator, Union, List
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig from .data_types.client import CompletionConfig, ChatCompletionConfig
logging.basicConfig( logging.basicConfig(
@@ -29,16 +28,24 @@ class APIClient:
DEFAULT_TIMEOUT = 4 DEFAULT_TIMEOUT = 4
def __init__( def __init__(
self, self, endpoint_group_name: str, api_key: str, server_url: str, instance: str
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
): ):
self.endpoint_group_name = endpoint_group_name self.endpoint_group_name = endpoint_group_name
self.api_key = api_key self.api_key = api_key
self.server_url = server_url self.server_url = server_url
self.endpoint_api_key = endpoint_api_key self.instance = instance
self.endpoint_api_key = self._get_endpoint_api_key()
def _get_endpoint_api_key(self) -> Optional[str]:
"""Get the endpoint API key"""
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
instance=self.instance,
)
if not endpoint_api_key:
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
return endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service""" """Get worker URL and auth data from routing service"""
@@ -91,13 +98,9 @@ class APIClient:
# Make the request using the specified method # Make the request using the specified method
if method.upper() == "POST": if method.upper() == "POST":
response = requests.post( response = requests.post(url, json=req_data, stream=stream)
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET": elif method.upper() == "GET":
response = requests.get( response = requests.get(url, params=req_data, stream=stream)
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else: else:
raise ValueError(f"Unsupported HTTP method: {method}") raise ValueError(f"Unsupported HTTP method: {method}")
@@ -551,24 +554,12 @@ def main():
sys.exit(1) sys.exit(1)
try: try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client # Create the core API client
client = APIClient( client = APIClient(
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url, server_url=args.server_url,
endpoint_api_key=endpoint_api_key, instance=args.instance,
) )
# Create tool manager and demo (passing the model parameter) # Create tool manager and demo (passing the model parameter)
+1 -7
View File
@@ -124,12 +124,7 @@ class CompletionsData(GenericData):
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
test_input = { test_input = {"model": model, "prompt": prompt, "temperature": 0.7}
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input) return cls(input=test_input)
@@ -163,7 +158,6 @@ class ChatCompletionsData(GenericData):
"model": model, "model": model,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 500,
} }
return cls(input=test_input) return cls(input=test_input)
+1 -6
View File
@@ -4,7 +4,6 @@ import json
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
@@ -43,11 +42,7 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
req_data = dict(payload=payload, auth_data=auth_data) req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT) url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}") print(f"url: {url}")
response = requests.post( response = requests.post(url, json=req_data)
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status() response.raise_for_status()
res = response.json() res = response.json()
print(res) print(res)