From 0bf3247a3426fa4cec491dfd5ab85f7375c97e9d Mon Sep 17 00:00:00 2001 From: Nader Arbabian Date: Mon, 11 Aug 2025 12:37:53 -0700 Subject: [PATCH] fix completions and interactive client --- utils/endpoint_util.py | 27 ++++++++++++++++++++------- workers/openai/client.py | 2 +- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/utils/endpoint_util.py b/utils/endpoint_util.py index 30cf074..37930af 100644 --- a/utils/endpoint_util.py +++ b/utils/endpoint_util.py @@ -16,6 +16,24 @@ class Endpoint: Utility class for handling endpoint operations. """ + @staticmethod + def get_autoscaler_server_url(instance: str) -> str: + endpoints = { + "alpha": "run-alpha", + "candidate": "run-candidate", + "prod": "run", + } + return f"https://{endpoints[instance]}.vast.ai/" + + @staticmethod + def get_server_url(instance: str) -> str: + endpoints = { + "alpha": "alpha", + "candidate": "candidate", + "prod": "console", + } + return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/" + @staticmethod def get_endpoint_api_key( endpoint_name: str, account_api_key: str, instance: str @@ -30,18 +48,13 @@ class Endpoint: Returns: Endpoint API key if successful, None otherwise """ - endpoints = { - "alpha": "alpha", - "candidate": "candidate", - "prod": "console", - } - vast_console_url = f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/" headers = {"Authorization": f"Bearer {account_api_key}"} try: log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") response = requests.get( - f"{vast_console_url}?autoscaler_instance={instance}", headers=headers + f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}", + headers=headers, ) if response.status_code != 200: diff --git a/workers/openai/client.py b/workers/openai/client.py index 79122a1..e34cc90 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -567,7 +567,7 @@ def main(): client = APIClient( endpoint_group_name=args.endpoint_group_name, api_key=args.api_key, - server_url=args.server_url, + server_url=Endpoint.get_autoscaler_server_url(args.instance), endpoint_api_key=endpoint_api_key, )