Endpoint update pr one (#1)

* Added endpoint flexibility along with existing log. extended the log support

* Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

* Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors

* Added endpoint flexibility along with existing log. extended the log support

Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors

* Endpoint Utils and API changes
This commit is contained in:
Abiola Akinnubi
2025-06-02 17:13:25 -07:00
committed by Nader Arbabian
parent b1ca68c349
commit 71ed54ebe4
11 changed files with 212 additions and 61 deletions
+31 -9
View File
@@ -1,12 +1,20 @@
import logging
from urllib.parse import urljoin
import requests
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
"""
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_default_workflow(
@@ -24,6 +32,7 @@ def call_default_workflow(
json=route_payload,
timeout=4,
)
response.raise_for_status()
message = response.json()
url = message["url"]
auth_data = dict(
@@ -43,6 +52,7 @@ def call_default_workflow(
url,
json=req_data,
)
response.raise_for_status()
print_truncate_res(str(response.json()))
@@ -61,6 +71,7 @@ def call_custom_workflow_for_sd3(
json=route_payload,
timeout=4,
)
response.raise_for_status()
message = response.json()
url = message["url"]
auth_data = dict(
@@ -131,6 +142,7 @@ def call_custom_workflow_for_sd3(
url,
json=req_data,
)
response.raise_for_status()
print_truncate_res(str(response.json()))
@@ -138,13 +150,23 @@ if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
call_default_workflow(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_custom_workflow_for_sd3(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
)
if endpoint_api_key:
try:
call_default_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_custom_workflow_for_sd3(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
+4 -4
View File
@@ -69,11 +69,11 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
return DefaultComfyWorkflowData
@@ -93,11 +93,11 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
return CustomComfyWorkflowData
+4 -4
View File
@@ -49,11 +49,11 @@ class GenerateHandler(EndpointHandler[InputData]):
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -97,11 +97,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
+47 -19
View File
@@ -1,8 +1,16 @@
import logging
import sys
import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
@@ -18,28 +26,31 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
url=url,
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
)
response = requests.post(url, json=req_data)
response.raise_for_status()
res = response.json()
print(res)
def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str):
def call_generate_stream(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/generate_stream"
COST = 100
route_payload = {
@@ -52,6 +63,7 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
print(f"url: {url}")
@@ -66,12 +78,17 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True)
response.raise_for_status() # Raise an exception for bad status codes
for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip()
if payload:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
try:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
except (json.JSONDecodeError, KeyError) as e:
log.warning(f"Failed to parse streaming response: {e}")
continue
print()
@@ -79,13 +96,24 @@ if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
call_generate(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
+13 -9
View File
@@ -14,8 +14,15 @@ from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
MODEL_SERVER_START_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
logging.basicConfig(
@@ -36,7 +43,7 @@ class GenerateHandler(EndpointHandler[InputData]):
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -62,11 +69,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -99,10 +106,7 @@ backend = Backend(
allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[
(LogAction.ModelLoaded, info_msg)
for info_msg in MODEL_SERVER_START_LOG_MSG
],
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)