Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ac51947b4 |
+1
-1
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
log.debug("starting server...")
|
||||
app = web.Application()
|
||||
app.add_routes(routes)
|
||||
runner = web.AppRunner(app)
|
||||
runner = web.AppRunner(app, handler_cancellation=True)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(
|
||||
runner,
|
||||
|
||||
@@ -10,7 +10,6 @@ from collections import Counter
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from urllib.parse import urljoin
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
import requests
|
||||
|
||||
from lib.data_types import AuthData, ApiPayload
|
||||
@@ -121,11 +120,9 @@ class ClientState:
|
||||
self.url = worker_address
|
||||
url = urljoin(worker_address, self.worker_endpoint)
|
||||
self.status = ClientStatus.Generating
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.infer_error.append(
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import tempfile
|
||||
from functools import cache
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@cache
|
||||
def get_cert_file_path():
|
||||
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||
response = requests.get(cert_url)
|
||||
response.raise_for_status()
|
||||
# Use a temporary file that is not deleted on close
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||
f.write(response.content)
|
||||
return f.name
|
||||
@@ -5,7 +5,6 @@ import requests
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
"""
|
||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||
@@ -52,7 +51,6 @@ def call_default_workflow(
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
@@ -143,7 +141,6 @@ def call_custom_workflow_for_sd3(
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
|
||||
+17
-26
@@ -6,7 +6,6 @@ from urllib.parse import urljoin
|
||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -29,16 +28,24 @@ class APIClient:
|
||||
DEFAULT_TIMEOUT = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_group_name: str,
|
||||
api_key: str,
|
||||
server_url: str,
|
||||
endpoint_api_key: str,
|
||||
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str
|
||||
):
|
||||
self.endpoint_group_name = endpoint_group_name
|
||||
self.api_key = api_key
|
||||
self.server_url = server_url
|
||||
self.endpoint_api_key = endpoint_api_key
|
||||
self.instance = instance
|
||||
self.endpoint_api_key = self._get_endpoint_api_key()
|
||||
|
||||
def _get_endpoint_api_key(self) -> Optional[str]:
|
||||
"""Get the endpoint API key"""
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=self.endpoint_group_name,
|
||||
account_api_key=self.api_key,
|
||||
instance=self.instance,
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
|
||||
return endpoint_api_key
|
||||
|
||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||
"""Get worker URL and auth data from routing service"""
|
||||
@@ -91,13 +98,9 @@ class APIClient:
|
||||
|
||||
# Make the request using the specified method
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(
|
||||
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
response = requests.post(url, json=req_data, stream=stream)
|
||||
elif method.upper() == "GET":
|
||||
response = requests.get(
|
||||
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
response = requests.get(url, params=req_data, stream=stream)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
@@ -551,24 +554,12 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
if not endpoint_api_key:
|
||||
log.error(
|
||||
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Create the core API client
|
||||
client = APIClient(
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
api_key=args.api_key,
|
||||
server_url=args.server_url,
|
||||
endpoint_api_key=endpoint_api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
# Create tool manager and demo (passing the model parameter)
|
||||
|
||||
@@ -124,12 +124,7 @@ class CompletionsData(GenericData):
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
test_input = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
test_input = {"model": model, "prompt": prompt, "temperature": 0.7}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@@ -163,7 +158,6 @@ class ChatCompletionsData(GenericData):
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
@@ -43,11 +42,7 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response = requests.post(url, json=req_data)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
print(res)
|
||||
|
||||
Reference in New Issue
Block a user