diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index 3d28e03..acdb1f3 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -2,6 +2,7 @@ import logging import uuid import random from urllib.parse import urljoin +import json import requests @@ -21,6 +22,41 @@ def call_text2image_workflow( endpoint_group_name: str, api_key: str, server_url: str ) -> None: """Simple Text2Image using the new modifier-based approach""" + + def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"): + """Helper function for making requests with consistent error handling""" + try: + response = requests.post( + url, + json=payload, + timeout=timeout, + verify=verify + ) + response.raise_for_status() + return response.json() + + except requests.exceptions.HTTPError as http_err: + log.error(f"HTTP error occurred during {context}: {http_err}") + log.error(f"Status Code: {response.status_code}") + log.error("Response content:", response.text) + return None + except requests.exceptions.Timeout: + log.error(f"Timeout occurred during {context}: {url}") + return None + except requests.exceptions.ConnectionError: + log.error(f"Connection error occurred during {context}: {url}") + return None + except json.JSONDecodeError as json_err: + log.error(f"Failed to decode JSON response during {context}: {json_err}") + if 'response' in locals(): + print("Response content:", response.text) + return None + except Exception as err: + log.error(f"An unexpected error occurred during {context}: {err}") + if 'response' in locals(): + log.error("Response content (if available):", response.text) + return None + WORKER_ENDPOINT = "/generate/sync" COST = 100 @@ -30,24 +66,30 @@ def call_text2image_workflow( "api_key": api_key, "cost": COST, } - response = requests.post( - urljoin(server_url, "/route/"), - json=route_payload, + + # First request - get routing information + route_response = make_request( + url=urljoin(server_url, "/route/"), + payload=route_payload, timeout=4, - ) - response.raise_for_status() - message = response.json() - url = message["url"] - auth_data = dict( - signature=message["signature"], - cost=message["cost"], - endpoint=message["endpoint"], - reqnum=message["reqnum"], - url=message["url"], + context="route request" ) - # Build the new payload structure - payload = { + if route_response is None: + return None + + # Extract data from route response + url = route_response["url"] + auth_data = dict( + signature=route_response["signature"], + cost=route_response["cost"], + endpoint=route_response["endpoint"], + reqnum=route_response["reqnum"], + url=route_response["url"], + ) + + # Build the payload for the worker request + worker_payload = { "input": { "request_id": str(uuid.uuid4()), "modifier": "Text2Image", @@ -63,17 +105,19 @@ def call_text2image_workflow( "expected_time": 30.0 # Expected 30 seconds on RTX4090 } - req_data = dict(payload=payload, auth_data=auth_data) - url = urljoin(url, WORKER_ENDPOINT) - print(f"url: {url}") + req_data = dict(payload=worker_payload, auth_data=auth_data) + worker_url = urljoin(url, WORKER_ENDPOINT) + print(f"url: {worker_url}") - response = requests.post( - url, - json=req_data, + # Second request - call the worker endpoint + worker_response = make_request( + url=worker_url, + payload=req_data, verify=get_cert_file_path(), + context="worker request" ) - response.raise_for_status() - print(str(response.json())) + + return worker_response if __name__ == "__main__": @@ -85,14 +129,16 @@ if __name__ == "__main__": account_api_key=args.api_key, instance=args.instance, ) + if endpoint_api_key: - try: - call_text2image_workflow( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - except Exception as e: - log.error(f"Error during API call: {e}") + result = call_text2image_workflow( + api_key=endpoint_api_key, + endpoint_group_name=args.endpoint_group_name, + server_url=args.server_url, + ) + if result is None: + log.error("Text2Image workflow failed") + else: + print(result) else: log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")