From e0be45f39ac5e7732b4c929b0a4f9afdbefd6e87 Mon Sep 17 00:00:00 2001 From: Rob Ballantyne Date: Fri, 18 Jul 2025 01:09:23 +0100 Subject: [PATCH] Addresses breaking change in core pyworker (#22) * Addresses breaking change in test_utils.py Endpoint.get_endpoint_api_key() now requires instance Moves the call to this function out of the APIClient and into main * Ensure make_benchmark_payload has a value to calculate the workload --------- Co-authored-by: Nader Arbabian --- lib/server.py | 2 +- workers/openai/client.py | 34 ++++++++++++++++------------- workers/openai/data_types/server.py | 8 ++++++- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/lib/server.py b/lib/server.py index 80e2959..b21c880 100644 --- a/lib/server.py +++ b/lib/server.py @@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): log.debug("starting server...") app = web.Application() app.add_routes(routes) - runner = web.AppRunner(app, handler_cancellation=True) + runner = web.AppRunner(app) await runner.setup() site = web.TCPSite( runner, diff --git a/workers/openai/client.py b/workers/openai/client.py index 4dbf099..af4f510 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -28,24 +28,16 @@ class APIClient: DEFAULT_TIMEOUT = 4 def __init__( - self, endpoint_group_name: str, api_key: str, server_url: str, instance: str + self, + endpoint_group_name: str, + api_key: str, + server_url: str, + endpoint_api_key: str, ): self.endpoint_group_name = endpoint_group_name self.api_key = api_key self.server_url = server_url - self.instance = instance - self.endpoint_api_key = self._get_endpoint_api_key() - - def _get_endpoint_api_key(self) -> Optional[str]: - """Get the endpoint API key""" - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=self.endpoint_group_name, - account_api_key=self.api_key, - instance=self.instance, - ) - if not endpoint_api_key: - log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}") - return endpoint_api_key + self.endpoint_api_key = endpoint_api_key def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: """Get worker URL and auth data from routing service""" @@ -554,12 +546,24 @@ def main(): sys.exit(1) try: + endpoint_api_key = Endpoint.get_endpoint_api_key( + endpoint_name=args.endpoint_group_name, + account_api_key=args.api_key, + instance=args.instance, + ) + + if not endpoint_api_key: + log.error( + f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting." + ) + sys.exit(1) + # Create the core API client client = APIClient( endpoint_group_name=args.endpoint_group_name, api_key=args.api_key, server_url=args.server_url, - instance=args.instance, + endpoint_api_key=endpoint_api_key, ) # Create tool manager and demo (passing the model parameter) diff --git a/workers/openai/data_types/server.py b/workers/openai/data_types/server.py index dd9b45c..92f204b 100644 --- a/workers/openai/data_types/server.py +++ b/workers/openai/data_types/server.py @@ -124,7 +124,12 @@ class CompletionsData(GenericData): if not model: raise ValueError("MODEL_NAME environment variable not set") - test_input = {"model": model, "prompt": prompt, "temperature": 0.7} + test_input = { + "model": model, + "prompt": prompt, + "temperature": 0.7, + "max_tokens": 500, + } return cls(input=test_input) @@ -158,6 +163,7 @@ class ChatCompletionsData(GenericData): "model": model, "messages": [{"role": "user", "content": prompt}], "temperature": 0.7, + "max_tokens": 500, } return cls(input=test_input)