Endpoint update pr one (#1)
* Added endpoint flexibility along with the existing log; extended the log support.
* Switched endpoint back to vast-ai; added endpoint flexibility along with the existing log; extended the log support.
* Modified the endpoint return type to be optional and checked via pyright to ensure there are no compilation/type errors.
* Added endpoint flexibility along with the existing log; extended the log support. Switched endpoint back to vast-ai; added endpoint flexibility along with the existing log; extended the log support. Modified the endpoint return type to be optional and checked via pyright to ensure there are no compilation/type errors.
* Endpoint utils and API changes.
This commit is contained in:
committed by
Nader Arbabian
parent
b1ca68c349
commit
71ed54ebe4
+47
-19
@@ -1,8 +1,16 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||
@@ -18,28 +26,31 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
url=url,
|
||||
)
|
||||
|
||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
)
|
||||
response = requests.post(url, json=req_data)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
print(res)
|
||||
|
||||
|
||||
def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str):
|
||||
def call_generate_stream(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/generate_stream"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
@@ -52,6 +63,7 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
print(f"url: {url}")
|
||||
@@ -66,12 +78,17 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
response = requests.post(url, json=req_data, stream=True)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
for line in response.iter_lines():
|
||||
payload = line.decode().lstrip("data:").rstrip()
|
||||
if payload:
|
||||
data = json.loads(payload)
|
||||
print(data["token"]["text"], end="")
|
||||
sys.stdout.flush()
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
print(data["token"]["text"], end="")
|
||||
sys.stdout.flush()
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
log.warning(f"Failed to parse streaming response: {e}")
|
||||
continue
|
||||
print()
|
||||
|
||||
|
||||
@@ -79,13 +96,24 @@ if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
call_generate(
|
||||
api_key=args.api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_generate_stream(
|
||||
api_key=args.api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
)
|
||||
if endpoint_api_key:
|
||||
try:
|
||||
call_generate(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_generate_stream(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during API call: {e}")
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||
|
||||
+13
-9
@@ -14,8 +14,15 @@ from .data_types import InputData
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
# Log lines that mark the model server as loaded: the text_generation_router targets report "Connected"
|
||||
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
|
||||
MODEL_SERVER_START_LOG_MSG = [
|
||||
'"message":"Connected","target":"text_generation_router"',
|
||||
'"message":"Connected","target":"text_generation_router::server"',
|
||||
]
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Error: WebserverFailed",
|
||||
"Error: DownloadError",
|
||||
"Error: ShardCannotStart",
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -36,7 +43,7 @@ class GenerateHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
@@ -62,11 +69,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
@@ -99,10 +106,7 @@ backend = Backend(
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[
|
||||
(LogAction.ModelLoaded, info_msg)
|
||||
for info_msg in MODEL_SERVER_START_LOG_MSG
|
||||
],
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
|
||||
Reference in New Issue
Block a user