Addresses breaking change in core pyworker (#22)

* Addresses breaking change in test_utils.py

Endpoint.get_endpoint_api_key() now requires instance

Moves the call to this function out of the APIClient and into main

* Ensure make_benchmark_payload has a value to calculate the workload

---------

Co-authored-by: Nader Arbabian <nader@vast.ai>
This commit is contained in:
Rob Ballantyne
2025-07-18 01:09:23 +01:00
committed by Nader Arbabian
parent be2aafdb1f
commit e0be45f39a
3 changed files with 27 additions and 17 deletions
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app, handler_cancellation=True) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
+19 -15
View File
@@ -28,24 +28,16 @@ class APIClient:
DEFAULT_TIMEOUT = 4 DEFAULT_TIMEOUT = 4
def __init__( def __init__(
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
): ):
self.endpoint_group_name = endpoint_group_name self.endpoint_group_name = endpoint_group_name
self.api_key = api_key self.api_key = api_key
self.server_url = server_url self.server_url = server_url
self.instance = instance self.endpoint_api_key = endpoint_api_key
self.endpoint_api_key = self._get_endpoint_api_key()
def _get_endpoint_api_key(self) -> Optional[str]:
"""Get the endpoint API key"""
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
instance=self.instance,
)
if not endpoint_api_key:
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
return endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service""" """Get worker URL and auth data from routing service"""
@@ -554,12 +546,24 @@ def main():
sys.exit(1) sys.exit(1)
try: try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client # Create the core API client
client = APIClient( client = APIClient(
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url, server_url=args.server_url,
instance=args.instance, endpoint_api_key=endpoint_api_key,
) )
# Create tool manager and demo (passing the model parameter) # Create tool manager and demo (passing the model parameter)
+7 -1
View File
@@ -124,7 +124,12 @@ class CompletionsData(GenericData):
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
test_input = {"model": model, "prompt": prompt, "temperature": 0.7} test_input = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input) return cls(input=test_input)
@@ -158,6 +163,7 @@ class ChatCompletionsData(GenericData):
"model": model, "model": model,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 500,
} }
return cls(input=test_input) return cls(input=test_input)