download vast.ai's root certificate in order to make pyworker requests
This commit is contained in:
@@ -10,6 +10,7 @@ from collections import Counter
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from lib.data_types import AuthData, ApiPayload
|
from lib.data_types import AuthData, ApiPayload
|
||||||
@@ -120,9 +121,11 @@ class ClientState:
|
|||||||
self.url = worker_address
|
self.url = worker_address
|
||||||
url = urljoin(worker_address, self.worker_endpoint)
|
url = urljoin(worker_address, self.worker_endpoint)
|
||||||
self.status = ClientStatus.Generating
|
self.status = ClientStatus.Generating
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
self.infer_error.append(
|
self.infer_error.append(
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
import tempfile
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_cert_file_path():
|
||||||
|
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||||
|
response = requests.get(cert_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
# Use a temporary file that is not deleted on close
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
return f.name
|
||||||
@@ -5,6 +5,7 @@ import requests
|
|||||||
|
|
||||||
from lib.test_utils import print_truncate_res
|
from lib.test_utils import print_truncate_res
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
"""
|
"""
|
||||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||||
@@ -51,6 +52,7 @@ def call_default_workflow(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from urllib.parse import urljoin
|
|||||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -90,9 +91,13 @@ class APIClient:
|
|||||||
|
|
||||||
# Make the request using the specified method
|
# Make the request using the specified method
|
||||||
if method.upper() == "POST":
|
if method.upper() == "POST":
|
||||||
response = requests.post(url, json=req_data, stream=stream)
|
response = requests.post(
|
||||||
|
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
elif method.upper() == "GET":
|
elif method.upper() == "GET":
|
||||||
response = requests.get(url, params=req_data, stream=stream)
|
response = requests.get(
|
||||||
|
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
|||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
print(f"url: {url}")
|
print(f"url: {url}")
|
||||||
response = requests.post(url, json=req_data)
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
res = response.json()
|
res = response.json()
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
Reference in New Issue
Block a user