Compare commits

..

3 Commits

Author SHA1 Message Date
Nader Arbabian 9773e5f67b download vast.ai's root certificate in order to make pyworker requests 2025-07-31 12:47:12 -07:00
Rob Ballantyne e0be45f39a Addresses breaking change in core pyworker (#22)
* Addresses breaking change in test_utils.py

Endpoint.get_endpoint_api_key() now requires instance

Moves the call to this function out of the APIClient and into main

* Ensure make_benchmark_payload has a value to calculate the workload

---------

Co-authored-by: Nader Arbabian <nader@vast.ai>
2025-07-18 16:11:10 -07:00
Nader Arbabian be2aafdb1f fix pyright errors + revert to old way of handling cancelled api requests (#23) 2025-07-17 16:59:06 -07:00
7 changed files with 61 additions and 20 deletions
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app, handler_cancellation=True) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
+3
View File
@@ -10,6 +10,7 @@ from collections import Counter
from dataclasses import dataclass, field, asdict from dataclasses import dataclass, field, asdict
from urllib.parse import urljoin from urllib.parse import urljoin
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import requests import requests
from lib.data_types import AuthData, ApiPayload from lib.data_types import AuthData, ApiPayload
@@ -120,9 +121,11 @@ class ClientState:
self.url = worker_address self.url = worker_address
url = urljoin(worker_address, self.worker_endpoint) url = urljoin(worker_address, self.worker_endpoint)
self.status = ClientStatus.Generating self.status = ClientStatus.Generating
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
if response.status_code != 200: if response.status_code != 200:
self.infer_error.append( self.infer_error.append(
+15
View File
@@ -0,0 +1,15 @@
import tempfile
from functools import cache
import requests
@cache
def get_cert_file_path():
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
response = requests.get(cert_url)
response.raise_for_status()
# Use a temporary file that is not deleted on close
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
f.write(response.content)
return f.name
+3
View File
@@ -5,6 +5,7 @@ import requests
from lib.test_utils import print_truncate_res from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
""" """
NOTE: this client example uses a custom comfy workflow compatible with SD3 only NOTE: this client example uses a custom comfy workflow compatible with SD3 only
@@ -51,6 +52,7 @@ def call_default_workflow(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
+26 -17
View File
@@ -6,6 +6,7 @@ from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List from typing import Dict, Any, Optional, Iterator, Union, List
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig from .data_types.client import CompletionConfig, ChatCompletionConfig
logging.basicConfig( logging.basicConfig(
@@ -28,24 +29,16 @@ class APIClient:
DEFAULT_TIMEOUT = 4 DEFAULT_TIMEOUT = 4
def __init__( def __init__(
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
): ):
self.endpoint_group_name = endpoint_group_name self.endpoint_group_name = endpoint_group_name
self.api_key = api_key self.api_key = api_key
self.server_url = server_url self.server_url = server_url
self.instance = instance self.endpoint_api_key = endpoint_api_key
self.endpoint_api_key = self._get_endpoint_api_key()
def _get_endpoint_api_key(self) -> Optional[str]:
"""Get the endpoint API key"""
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
instance=self.instance,
)
if not endpoint_api_key:
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
return endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service""" """Get worker URL and auth data from routing service"""
@@ -98,9 +91,13 @@ class APIClient:
# Make the request using the specified method # Make the request using the specified method
if method.upper() == "POST": if method.upper() == "POST":
response = requests.post(url, json=req_data, stream=stream) response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET": elif method.upper() == "GET":
response = requests.get(url, params=req_data, stream=stream) response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else: else:
raise ValueError(f"Unsupported HTTP method: {method}") raise ValueError(f"Unsupported HTTP method: {method}")
@@ -554,12 +551,24 @@ def main():
sys.exit(1) sys.exit(1)
try: try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client # Create the core API client
client = APIClient( client = APIClient(
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url, server_url=args.server_url,
instance=args.instance, endpoint_api_key=endpoint_api_key,
) )
# Create tool manager and demo (passing the model parameter) # Create tool manager and demo (passing the model parameter)
+7 -1
View File
@@ -124,7 +124,12 @@ class CompletionsData(GenericData):
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
test_input = {"model": model, "prompt": prompt, "temperature": 0.7} test_input = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input) return cls(input=test_input)
@@ -158,6 +163,7 @@ class ChatCompletionsData(GenericData):
"model": model, "model": model,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 500,
} }
return cls(input=test_input) return cls(input=test_input)
+6 -1
View File
@@ -4,6 +4,7 @@ import json
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
req_data = dict(payload=payload, auth_data=auth_data) req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT) url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}") print(f"url: {url}")
response = requests.post(url, json=req_data) response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status() response.raise_for_status()
res = response.json() res = response.json()
print(res) print(res)