diff --git a/lib/test_utils.py b/lib/test_utils.py index ba97611..8635027 100644 --- a/lib/test_utils.py +++ b/lib/test_utils.py @@ -10,6 +10,7 @@ from collections import Counter from dataclasses import dataclass, field, asdict from urllib.parse import urljoin from utils.endpoint_util import Endpoint +from utils.ssl import get_cert_file_path import requests from lib.data_types import AuthData, ApiPayload @@ -120,9 +121,11 @@ class ClientState: self.url = worker_address url = urljoin(worker_address, self.worker_endpoint) self.status = ClientStatus.Generating + response = requests.post( url, json=req_data, + verify=get_cert_file_path(), ) if response.status_code != 200: self.infer_error.append( diff --git a/utils/ssl.py b/utils/ssl.py new file mode 100644 index 0000000..5406ac8 --- /dev/null +++ b/utils/ssl.py @@ -0,0 +1,15 @@ +import tempfile +from functools import cache + +import requests + + +@cache +def get_cert_file_path(): + cert_url = "https://console.vast.ai/static/jvastai_root.cer" + response = requests.get(cert_url) + response.raise_for_status() + # Use a temporary file that is not deleted on close + with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f: + f.write(response.content) + return f.name diff --git a/workers/comfyui/client.py b/workers/comfyui/client.py index 6563e00..771371e 100644 --- a/workers/comfyui/client.py +++ b/workers/comfyui/client.py @@ -5,6 +5,7 @@ import requests from lib.test_utils import print_truncate_res from utils.endpoint_util import Endpoint +from utils.ssl import get_cert_file_path """ NOTE: this client example uses a custom comfy workflow compatible with SD3 only @@ -51,6 +52,7 @@ def call_default_workflow( response = requests.post( url, json=req_data, + verify=get_cert_file_path(), ) response.raise_for_status() print_truncate_res(str(response.json())) @@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3( response = requests.post( url, json=req_data, + verify=get_cert_file_path(), ) response.raise_for_status() print_truncate_res(str(response.json())) diff --git a/workers/openai/client.py b/workers/openai/client.py index af4f510..79122a1 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -6,6 +6,7 @@ from urllib.parse import urljoin from typing import Dict, Any, Optional, Iterator, Union, List import requests from utils.endpoint_util import Endpoint +from utils.ssl import get_cert_file_path from .data_types.client import CompletionConfig, ChatCompletionConfig logging.basicConfig( @@ -90,9 +91,13 @@ class APIClient: # Make the request using the specified method if method.upper() == "POST": - response = requests.post(url, json=req_data, stream=stream) + response = requests.post( + url, json=req_data, stream=stream, verify=get_cert_file_path() + ) elif method.upper() == "GET": - response = requests.get(url, params=req_data, stream=stream) + response = requests.get( + url, params=req_data, stream=stream, verify=get_cert_file_path() + ) else: raise ValueError(f"Unsupported HTTP method: {method}") diff --git a/workers/tgi/client.py b/workers/tgi/client.py index 7e4f1bb..66dacb9 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -4,6 +4,7 @@ import json from urllib.parse import urljoin import requests from utils.endpoint_util import Endpoint +from utils.ssl import get_cert_file_path logging.basicConfig( level=logging.DEBUG, @@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No req_data = dict(payload=payload, auth_data=auth_data) url = urljoin(url, WORKER_ENDPOINT) print(f"url: {url}") - response = requests.post(url, json=req_data) + response = requests.post( + url, + json=req_data, + verify=get_cert_file_path(), + ) response.raise_for_status() res = response.json() print(res)