Endpoint update pr one (#1)
* Added endpoint flexibility along with existing log. extended the log support * Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support * Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors * Added endpoint flexibility along with existing log. extended the log support Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors * Endpoint Utils and API changes
This commit is contained in:
committed by
Nader Arbabian
parent
b1ca68c349
commit
71ed54ebe4
@@ -1,3 +1,5 @@
|
|||||||
.direnv
|
.direnv
|
||||||
.envrc
|
.envrc
|
||||||
__pycache__
|
__pycache__
|
||||||
|
bin/
|
||||||
|
lib64
|
||||||
+7
-7
@@ -186,7 +186,6 @@ class Backend:
|
|||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
|
|
||||||
async def __healthcheck(self):
|
async def __healthcheck(self):
|
||||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||||
if health_check_url is None:
|
if health_check_url is None:
|
||||||
@@ -200,19 +199,20 @@ class Backend:
|
|||||||
log.debug("Healthcheck successful")
|
log.debug("Healthcheck successful")
|
||||||
elif response.status == 503:
|
elif response.status == 503:
|
||||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||||
self.backend_errored(f"Healthcheck failed with status: {response.status}")
|
self.backend_errored(
|
||||||
|
f"Healthcheck failed with status: {response.status}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# endpoint not ready yet so bail
|
# endpoint not ready yet so bail
|
||||||
log.debug(
|
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||||
f"Healthcheck Endpoint not ready: {response.status}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Healthcheck failed with exception: {e}")
|
log.debug(f"Healthcheck failed with exception: {e}")
|
||||||
self.backend_errored(str(e))
|
self.backend_errored(str(e))
|
||||||
|
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck())
|
await gather(
|
||||||
|
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
|
||||||
|
)
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
self.metrics._model_errored(msg)
|
self.metrics._model_errored(msg)
|
||||||
|
|||||||
+6
-4
@@ -8,6 +8,7 @@ from aiohttp import web, ClientResponse
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -22,8 +23,10 @@ class JsonDataException(Exception):
|
|||||||
def __init__(self, json_msg: Dict[str, Any]):
|
def __init__(self, json_msg: Dict[str, Any]):
|
||||||
self.message = json_msg
|
self.message = json_msg
|
||||||
|
|
||||||
|
|
||||||
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
|
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ApiPayload(ABC):
|
class ApiPayload(ABC):
|
||||||
|
|
||||||
@@ -45,7 +48,9 @@ class ApiPayload(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T:
|
def from_json_msg(
|
||||||
|
cls: Type[ApiPayload_T], json_msg: Dict[str, Any]
|
||||||
|
) -> ApiPayload_T:
|
||||||
"""
|
"""
|
||||||
defines how to create an API payload from a JSON message,
|
defines how to create an API payload from a JSON message,
|
||||||
it should throw an JsonDataException if there are issues with some fields
|
it should throw an JsonDataException if there are issues with some fields
|
||||||
@@ -84,8 +89,6 @@ class AuthData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
||||||
"""
|
"""
|
||||||
@@ -108,7 +111,6 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
|||||||
"""the endpoint on the model API that is used for healthchecks"""
|
"""the endpoint on the model API that is used for healthchecks"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def payload_cls(cls) -> Type[ApiPayload_T]:
|
def payload_cls(cls) -> Type[ApiPayload_T]:
|
||||||
|
|||||||
+27
-2
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
@@ -13,6 +14,13 @@ import requests
|
|||||||
|
|
||||||
from lib.data_types import AuthData, ApiPayload
|
from lib.data_types import AuthData, ApiPayload
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
class ClientStatus(Enum):
|
class ClientStatus(Enum):
|
||||||
FetchEndpoint = 1
|
FetchEndpoint = 1
|
||||||
@@ -71,9 +79,19 @@ class ClientState:
|
|||||||
|
|
||||||
def make_call(self):
|
def make_call(self):
|
||||||
self.status = ClientStatus.FetchEndpoint
|
self.status = ClientStatus.FetchEndpoint
|
||||||
|
endpoint_api_key = AuthData.get_endpoint_api_key(
|
||||||
|
endpoint_name=self.endpoint_group_name,
|
||||||
|
account_api_key=self.api_key,
|
||||||
|
)
|
||||||
|
if not endpoint_api_key:
|
||||||
|
self.as_error.append(
|
||||||
|
f"Endpoint {self.endpoint_group_name} not found for API key",
|
||||||
|
)
|
||||||
|
self.status = ClientStatus.Error
|
||||||
|
return
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint": self.endpoint_group_name,
|
"endpoint": self.endpoint_group_name,
|
||||||
"api_key": self.api_key,
|
"api_key": endpoint_api_key,
|
||||||
"cost": self.payload.count_workload(),
|
"cost": self.payload.count_workload(),
|
||||||
}
|
}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@@ -215,11 +233,18 @@ def run_test(
|
|||||||
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
||||||
print_thread.daemon = True # makes threads get killed on program exit
|
print_thread.daemon = True # makes threads get killed on program exit
|
||||||
print_thread.start()
|
print_thread.start()
|
||||||
|
endpoint_api_key = AuthData.get_endpoint_api_key(
|
||||||
|
endpoint_name=endpoint_group_name,
|
||||||
|
account_api_key=api_key,
|
||||||
|
)
|
||||||
|
if not endpoint_api_key:
|
||||||
|
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
for _ in range(num_requests):
|
for _ in range(num_requests):
|
||||||
client = ClientState(
|
client = ClientState(
|
||||||
endpoint_group_name=endpoint_group_name,
|
endpoint_group_name=endpoint_group_name,
|
||||||
api_key=api_key,
|
api_key=endpoint_api_key,
|
||||||
server_url=server_url,
|
server_url=server_url,
|
||||||
worker_endpoint=worker_endpoint,
|
worker_endpoint=worker_endpoint,
|
||||||
payload=payload_cls.for_test(),
|
payload=payload_cls.for_test(),
|
||||||
|
|||||||
+2
-3
@@ -8,11 +8,10 @@ SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
|||||||
ENV_PATH="$WORKSPACE_DIR/worker-env"
|
ENV_PATH="$WORKSPACE_DIR/worker-env"
|
||||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
MODEL_LOG="$WORKSPACE_DIR/model.log"
|
||||||
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||||
USE_SSL="${USE_SSL:-true}"
|
USE_SSL="${USE_SSL:-true}"
|
||||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||||
|
|
||||||
mkdir -p "$WORKSPACE_DIR"
|
mkdir -p "$WORKSPACE_DIR"
|
||||||
cd "$WORKSPACE_DIR"
|
cd "$WORKSPACE_DIR"
|
||||||
|
|
||||||
@@ -47,7 +46,7 @@ env | grep _ >> /etc/environment;
|
|||||||
|
|
||||||
if [ ! -d "$ENV_PATH" ]
|
if [ ! -d "$ENV_PATH" ]
|
||||||
then
|
then
|
||||||
apt install -y python3.10-venv
|
apt install -y python3.12-venv
|
||||||
echo "setting up venv"
|
echo "setting up venv"
|
||||||
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
from logging import log
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class Endpoint:
|
||||||
|
"""
|
||||||
|
Utility class for handling endpoint operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Fetch endpoint API key from VastAI console following the healthcheck pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
endpoint_name: Name of the endpoint
|
||||||
|
account_api_key: Account API key for authentication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Endpoint API key if successful, None otherwise
|
||||||
|
"""
|
||||||
|
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
|
||||||
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||||
|
response = requests.get(vast_console_url, headers=headers)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
|
||||||
|
log.debug(error_msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except requests.exceptions.JSONDecodeError as e:
|
||||||
|
log.debug(f"Failed to parse JSON response: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = data.get("results", [])
|
||||||
|
|
||||||
|
endpoint: Optional[Dict[str, Any]] = next(
|
||||||
|
(item for item in result if item["endpoint_name"] == endpoint_name),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not endpoint:
|
||||||
|
error_msg = f"Endpoint '{endpoint_name}' not found."
|
||||||
|
log.debug(error_msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
endpoint_api_key = endpoint.get("api_key")
|
||||||
|
if not endpoint_api_key:
|
||||||
|
error_msg = f"API key for endpoint '{endpoint_name}' not found."
|
||||||
|
log.debug(error_msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
log.debug(f"Successfully retrieved API key for endpoint: {endpoint_name}")
|
||||||
|
return endpoint_api_key
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
error_msg = f"Request error while fetching endpoint API key: {e}"
|
||||||
|
log.debug(error_msg)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error while fetching endpoint API key: {e}"
|
||||||
|
log.debug(error_msg)
|
||||||
|
return None
|
||||||
@@ -1,12 +1,20 @@
|
|||||||
|
import logging
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from lib.test_utils import print_truncate_res
|
from lib.test_utils import print_truncate_res
|
||||||
|
from utils.endpoint_util import Endpoint
|
||||||
|
|
||||||
"""
|
"""
|
||||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||||
"""
|
"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
def call_default_workflow(
|
def call_default_workflow(
|
||||||
@@ -24,6 +32,7 @@ def call_default_workflow(
|
|||||||
json=route_payload,
|
json=route_payload,
|
||||||
timeout=4,
|
timeout=4,
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
message = response.json()
|
message = response.json()
|
||||||
url = message["url"]
|
url = message["url"]
|
||||||
auth_data = dict(
|
auth_data = dict(
|
||||||
@@ -43,6 +52,7 @@ def call_default_workflow(
|
|||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +71,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
json=route_payload,
|
json=route_payload,
|
||||||
timeout=4,
|
timeout=4,
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
message = response.json()
|
message = response.json()
|
||||||
url = message["url"]
|
url = message["url"]
|
||||||
auth_data = dict(
|
auth_data = dict(
|
||||||
@@ -131,6 +142,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
|
|
||||||
|
|
||||||
@@ -138,13 +150,23 @@ if __name__ == "__main__":
|
|||||||
from lib.test_utils import test_args
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
args = test_args.parse_args()
|
args = test_args.parse_args()
|
||||||
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
)
|
||||||
|
if endpoint_api_key:
|
||||||
|
try:
|
||||||
call_default_workflow(
|
call_default_workflow(
|
||||||
api_key=args.api_key,
|
api_key=endpoint_api_key,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
call_custom_workflow_for_sd3(
|
call_custom_workflow_for_sd3(
|
||||||
api_key=args.api_key,
|
api_key=endpoint_api_key,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error during API call: {e}")
|
||||||
|
else:
|
||||||
|
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||||
|
|||||||
+37
-9
@@ -1,8 +1,16 @@
|
|||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from utils.endpoint_util import Endpoint
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||||
@@ -18,28 +26,31 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
|||||||
json=route_payload,
|
json=route_payload,
|
||||||
timeout=4,
|
timeout=4,
|
||||||
)
|
)
|
||||||
|
response.raise_for_status() # Raise an exception for bad status codes
|
||||||
message = response.json()
|
message = response.json()
|
||||||
url = message["url"]
|
url = message["url"]
|
||||||
|
|
||||||
auth_data = dict(
|
auth_data = dict(
|
||||||
signature=message["signature"],
|
signature=message["signature"],
|
||||||
cost=message["cost"],
|
cost=message["cost"],
|
||||||
endpoint=message["endpoint"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=url,
|
||||||
)
|
)
|
||||||
|
|
||||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
print(f"url: {url}")
|
print(f"url: {url}")
|
||||||
response = requests.post(
|
response = requests.post(url, json=req_data)
|
||||||
url,
|
response.raise_for_status()
|
||||||
json=req_data,
|
|
||||||
)
|
|
||||||
res = response.json()
|
res = response.json()
|
||||||
print(res)
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str):
|
def call_generate_stream(
|
||||||
|
endpoint_group_name: str, api_key: str, server_url: str
|
||||||
|
) -> None:
|
||||||
WORKER_ENDPOINT = "/generate_stream"
|
WORKER_ENDPOINT = "/generate_stream"
|
||||||
COST = 100
|
COST = 100
|
||||||
route_payload = {
|
route_payload = {
|
||||||
@@ -52,6 +63,7 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
|
|||||||
json=route_payload,
|
json=route_payload,
|
||||||
timeout=4,
|
timeout=4,
|
||||||
)
|
)
|
||||||
|
response.raise_for_status() # Raise an exception for bad status codes
|
||||||
message = response.json()
|
message = response.json()
|
||||||
url = message["url"]
|
url = message["url"]
|
||||||
print(f"url: {url}")
|
print(f"url: {url}")
|
||||||
@@ -66,12 +78,17 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
|
|||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
response = requests.post(url, json=req_data, stream=True)
|
response = requests.post(url, json=req_data, stream=True)
|
||||||
|
response.raise_for_status() # Raise an exception for bad status codes
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
payload = line.decode().lstrip("data:").rstrip()
|
payload = line.decode().lstrip("data:").rstrip()
|
||||||
if payload:
|
if payload:
|
||||||
|
try:
|
||||||
data = json.loads(payload)
|
data = json.loads(payload)
|
||||||
print(data["token"]["text"], end="")
|
print(data["token"]["text"], end="")
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
|
log.warning(f"Failed to parse streaming response: {e}")
|
||||||
|
continue
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
@@ -79,13 +96,24 @@ if __name__ == "__main__":
|
|||||||
from lib.test_utils import test_args
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
args = test_args.parse_args()
|
args = test_args.parse_args()
|
||||||
|
|
||||||
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
)
|
||||||
|
if endpoint_api_key:
|
||||||
|
try:
|
||||||
call_generate(
|
call_generate(
|
||||||
api_key=args.api_key,
|
api_key=endpoint_api_key,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
call_generate_stream(
|
call_generate_stream(
|
||||||
api_key=args.api_key,
|
api_key=endpoint_api_key,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error during API call: {e}")
|
||||||
|
else:
|
||||||
|
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||||
|
|||||||
+10
-6
@@ -14,8 +14,15 @@ from .data_types import InputData
|
|||||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||||
|
|
||||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
|
MODEL_SERVER_START_LOG_MSG = [
|
||||||
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
|
'"message":"Connected","target":"text_generation_router"',
|
||||||
|
'"message":"Connected","target":"text_generation_router::server"',
|
||||||
|
]
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
|
"Error: WebserverFailed",
|
||||||
|
"Error: DownloadError",
|
||||||
|
"Error: ShardCannotStart",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -99,10 +106,7 @@ backend = Backend(
|
|||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
log_actions=[
|
log_actions=[
|
||||||
*[
|
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||||
(LogAction.ModelLoaded, info_msg)
|
|
||||||
for info_msg in MODEL_SERVER_START_LOG_MSG
|
|
||||||
],
|
|
||||||
(LogAction.Info, '"message":"Download'),
|
(LogAction.Info, '"message":"Download'),
|
||||||
*[
|
*[
|
||||||
(LogAction.ModelError, error_msg)
|
(LogAction.ModelError, error_msg)
|
||||||
|
|||||||
Reference in New Issue
Block a user