diff --git a/lib/data_types.py b/lib/data_types.py index d948c60..f051cd3 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -66,7 +66,7 @@ class AuthData: """data used to authenticate requester""" cost: str - endpoint: str + endpoint_id: int reqnum: int request_idx: int signature: str diff --git a/lib/test_utils.py b/lib/test_utils.py index d64a4b6..a74c605 100644 --- a/lib/test_utils.py +++ b/lib/test_utils.py @@ -75,6 +75,7 @@ def print_truncate_res(res: str): @dataclass class ClientState: endpoint_group_name: str + endpoint_id: int api_key: str server_url: str worker_endpoint: str @@ -95,7 +96,7 @@ class ClientState: self.status = ClientStatus.Error return route_payload = { - "endpoint": self.endpoint_group_name, + "endpoint_id": self.endpoint_id, "api_key": self.api_key, "cost": self.payload.count_workload(), } @@ -244,16 +245,19 @@ def run_test( print_thread = threading.Thread(target=print_state, args=(clients, num_requests)) print_thread.daemon = True # makes threads get killed on program exit print_thread.start() - endpoint_api_key = Endpoint.get_endpoint_api_key( + endpoint_info = Endpoint.get_endpoint_info( endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance ) - if not endpoint_api_key: + if not endpoint_info: log.debug(f"Endpoint {endpoint_group_name} not found for API key") return + endpoint_id = endpoint_info["id"] + endpoint_api_key = endpoint_info["api_key"] try: for _ in range(num_requests): client = ClientState( endpoint_group_name=endpoint_group_name, + endpoint_id=endpoint_id, api_key=endpoint_api_key, server_url=server_url, worker_endpoint=worker_endpoint, diff --git a/workers/comfyui/client.py b/workers/comfyui/client.py index 7d1935e..aa9e1da 100644 --- a/workers/comfyui/client.py +++ b/workers/comfyui/client.py @@ -13,11 +13,11 @@ from vastai import Serverless ENDPOINT_NAME = "my-comfyui-endpoint" COST = 100 # Use a constant cost for image generation -def call_default_workflow(client: Serverless) -> None: +def call_default_workflow(endpoint_id: int, api_key: str, server_url: str) -> None: WORKER_ENDPOINT = "/prompt" COST = 100 route_payload = { - "endpoint": endpoint_group_name, + "endpoint_id": endpoint_id, "api_key": api_key, "cost": COST, } @@ -32,7 +32,7 @@ def call_default_workflow(client: Serverless) -> None: auth_data = dict( signature=message["signature"], cost=message["cost"], - endpoint=message["endpoint"], + endpoint_id=message["endpoint_id"], reqnum=message["reqnum"], url=message["url"], ) @@ -52,12 +52,12 @@ def call_default_workflow(client: Serverless) -> None: def call_custom_workflow_for_sd3( - endpoint_group_name: str, api_key: str, server_url: str + endpoint_id: int, api_key: str, server_url: str ) -> None: WORKER_ENDPOINT = "/custom-workflow" COST = 100 route_payload = { - "endpoint": endpoint_group_name, + "endpoint_id": endpoint_id, "api_key": api_key, "cost": COST, } @@ -72,7 +72,7 @@ def call_custom_workflow_for_sd3( auth_data = dict( signature=message["signature"], cost=message["cost"], - endpoint=message["endpoint"], + endpoint_id=message["endpoint_id"], reqnum=message["reqnum"], url=message["url"], request_idx=message["request_idx"], @@ -146,25 +146,28 @@ def call_custom_workflow_for_sd3( if __name__ == "__main__": from lib.test_utils import test_args + log = logging.getLogger(__name__) args = test_args.parse_args() - endpoint_api_key = Endpoint.get_endpoint_api_key( + endpoint_info = Endpoint.get_endpoint_info( endpoint_name=args.endpoint_group_name, account_api_key=args.api_key, instance=args.instance, ) - if endpoint_api_key: + if endpoint_info: + endpoint_id = endpoint_info["id"] + endpoint_api_key = endpoint_info["api_key"] try: call_default_workflow( + endpoint_id=endpoint_id, api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, server_url=args.server_url, ) call_custom_workflow_for_sd3( + endpoint_id=endpoint_id, api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, server_url=args.server_url, ) except Exception as e: log.error(f"Error during API call: {e}") else: - log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ") + log.error(f"Failed to get endpoint info for {args.endpoint_group_name}") diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py index 9cb5f37..7b1f090 100644 --- a/workers/openai/test_load.py +++ b/workers/openai/test_load.py @@ -60,7 +60,7 @@ def do_one(endpoint_name: str, worker_session): try: workload = payload.count_workload() - route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload} + route_payload = {"endpoint_id": endpoint_id, "api_key": endpoint_api_key, "cost": workload} headers = {"Authorization": f"Bearer {endpoint_api_key}"} start = time.time() r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)