Files
pyworker/workers/tgi/client.py
T

121 lines
3.6 KiB
Python
Raw Normal View History

2025-06-02 17:13:25 -07:00
import logging
2024-09-04 11:19:30 -07:00
import sys
import json
from urllib.parse import urljoin
import requests
2025-06-02 17:13:25 -07:00
from utils.endpoint_util import Endpoint
# Configure the root logger once for this script: DEBUG and above, with a
# timestamp and a fixed-width level tag in front of every message.
logging.basicConfig(
    format="%(asctime)s[%(levelname)-5s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.DEBUG,
)

# Module logger, named after this file's path.
log = logging.getLogger(name=__file__)
2024-09-04 11:19:30 -07:00
def call_generate(
    endpoint_group_name: str,
    api_key: str,
    server_url: str,
    worker_timeout: float = 300.0,
) -> None:
    """Send one non-streaming generation request via the router and print it.

    Two HTTP calls are made:

    1. POST ``<server_url>/route/`` with the endpoint name, API key and a
       fixed cost. The JSON reply supplies the worker URL and the routing
       fields (signature, cost, endpoint, reqnum).
    2. POST ``<worker url>/generate`` with a fixed prompt plus those routing
       fields as ``auth_data``. The decoded JSON reply is printed.

    Args:
        endpoint_group_name: Endpoint group to route the request to.
        api_key: API key sent to the route endpoint.
        server_url: Base URL of the routing server.
        worker_timeout: Timeout in seconds for the worker request, passed to
            ``requests``. Without a timeout a stalled worker would block this
            call indefinitely.

    Raises:
        requests.HTTPError: If either request returns an error status.
        requests.Timeout: If either request exceeds its timeout.
        KeyError: If the route reply lacks one of the expected fields.
    """
    worker_endpoint = "/generate"
    cost = 100

    route_payload = {
        "endpoint": endpoint_group_name,
        "api_key": api_key,
        "cost": cost,
    }
    response = requests.post(
        urljoin(server_url, "/route/"),
        json=route_payload,
        timeout=4,
    )
    response.raise_for_status()
    message = response.json()
    worker_base_url = message["url"]

    # The routing fields are forwarded to the worker alongside the prompt.
    auth_data = dict(
        signature=message["signature"],
        cost=message["cost"],
        endpoint=message["endpoint"],
        reqnum=message["reqnum"],
        url=worker_base_url,
    )
    payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
    req_data = dict(payload=payload, auth_data=auth_data)

    # worker_endpoint is an absolute path, so urljoin replaces any path that
    # the worker URL may carry.
    url = urljoin(worker_base_url, worker_endpoint)
    print(f"url: {url}")

    response = requests.post(url, json=req_data, timeout=worker_timeout)
    response.raise_for_status()
    print(response.json())
2025-06-02 17:13:25 -07:00
def call_generate_stream(
    endpoint_group_name: str,
    api_key: str,
    server_url: str,
    worker_timeout: float = 300.0,
) -> None:
    """Send one streaming generation request via the router and echo tokens.

    Two HTTP calls are made:

    1. POST ``<server_url>/route/`` to obtain the worker URL and routing
       fields (signature, cost, endpoint, reqnum).
    2. POST ``<worker url>/generate_stream`` with ``stream=True``.

    Each non-empty response line has an optional ``data:`` prefix removed and
    is parsed as JSON. ``data["token"]["text"]`` is printed without a newline,
    and stdout is flushed so the text appears as it arrives. Lines that fail
    to parse are logged as warnings and skipped.

    Args:
        endpoint_group_name: Endpoint group to route the request to.
        api_key: API key sent to the route endpoint.
        server_url: Base URL of the routing server.
        worker_timeout: Timeout in seconds for the worker request, passed to
            ``requests``. For a streamed response this bounds the connect and
            each wait between received chunks, not the total stream duration.

    Raises:
        requests.HTTPError: If either request returns an error status.
        requests.Timeout: If either request exceeds its timeout.
        KeyError: If the route reply lacks one of the expected fields.
    """
    worker_endpoint = "/generate_stream"
    cost = 100
    data_prefix = "data:"

    route_payload = {
        "endpoint": endpoint_group_name,
        "api_key": api_key,
        "cost": cost,
    }
    response = requests.post(
        urljoin(server_url, "/route/"),
        json=route_payload,
        timeout=4,
    )
    response.raise_for_status()
    message = response.json()
    worker_base_url = message["url"]

    # The routing fields are forwarded to the worker alongside the prompt.
    auth_data = dict(
        signature=message["signature"],
        cost=message["cost"],
        endpoint=message["endpoint"],
        reqnum=message["reqnum"],
        url=worker_base_url,
    )
    payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
    req_data = dict(payload=payload, auth_data=auth_data)

    url = urljoin(worker_base_url, worker_endpoint)
    print(f"url: {url}")

    # The context manager releases the streamed connection even if the status
    # check or the line loop raises.
    with requests.post(
        url, json=req_data, stream=True, timeout=worker_timeout
    ) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines():
            line = raw_line.decode("utf-8")
            # Remove the literal "data:" prefix. str.lstrip("data:") would
            # instead strip any leading run of the characters d, a, t and ':'.
            if line.startswith(data_prefix):
                line = line[len(data_prefix):]
            chunk = line.strip()
            if not chunk:
                # Blank/keep-alive lines carry no token.
                continue
            try:
                data = json.loads(chunk)
                text = data["token"]["text"]
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # TypeError covers e.g. a null "token" field.
                log.warning(f"Failed to parse streaming response: {e}")
                continue
            print(text, end="")
            sys.stdout.flush()
    print()
if __name__ == "__main__":
    # Imported inside the guard so that importing this module does not
    # require lib.test_utils.
    from lib.test_utils import test_args

    args = test_args.parse_args()

    endpoint_api_key = Endpoint.get_endpoint_api_key(
        endpoint_name=args.endpoint_group_name,
        account_api_key=args.api_key,
        instance=args.instance,
    )
    if not endpoint_api_key:
        log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
        # Non-zero exit status so callers/CI can detect the failure.
        sys.exit(1)

    try:
        call_generate(
            api_key=endpoint_api_key,
            endpoint_group_name=args.endpoint_group_name,
            server_url=args.server_url,
        )
        call_generate_stream(
            api_key=endpoint_api_key,
            endpoint_group_name=args.endpoint_group_name,
            server_url=args.server_url,
        )
    except Exception as e:
        # Top-level boundary: log.exception keeps the traceback that a plain
        # log.error would drop, and the exit status reports the failure.
        log.exception(f"Error during API call: {e}")
        sys.exit(1)