Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9773e5f67b | |||
| e0be45f39a | |||
| be2aafdb1f |
+16
-2
@@ -5,7 +5,7 @@ import base64
|
|||||||
import subprocess
|
import subprocess
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from asyncio import sleep, gather, Semaphore
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
@@ -123,6 +123,12 @@ class Backend:
|
|||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
|
|
||||||
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
|
await request.wait_for_disconnection()
|
||||||
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
|
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
||||||
|
return web.Response(status=500)
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
log.debug(f"got request, {auth_data.reqnum}")
|
log.debug(f"got request, {auth_data.reqnum}")
|
||||||
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
||||||
@@ -168,7 +174,15 @@ class Backend:
|
|||||||
return web.Response(status=401)
|
return web.Response(status=401)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await make_request()
|
done, pending = await wait(
|
||||||
|
[
|
||||||
|
create_task(make_request()),
|
||||||
|
create_task(cancel_api_call_if_disconnected()),
|
||||||
|
],
|
||||||
|
return_when=FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
[task.cancel() for task in pending]
|
||||||
|
return done.pop().result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|||||||
+1
-1
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
|||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app, handler_cancellation=True)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from collections import Counter
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from lib.data_types import AuthData, ApiPayload
|
from lib.data_types import AuthData, ApiPayload
|
||||||
@@ -120,9 +121,11 @@ class ClientState:
|
|||||||
self.url = worker_address
|
self.url = worker_address
|
||||||
url = urljoin(worker_address, self.worker_endpoint)
|
url = urljoin(worker_address, self.worker_endpoint)
|
||||||
self.status = ClientStatus.Generating
|
self.status = ClientStatus.Generating
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
self.infer_error.append(
|
self.infer_error.append(
|
||||||
|
|||||||
+1
-1
@@ -1,4 +1,4 @@
|
|||||||
aiohttp~=3.11
|
aiohttp==3.10.1
|
||||||
anyio~=4.4
|
anyio~=4.4
|
||||||
lib~=4.0
|
lib~=4.0
|
||||||
nltk~=3.9
|
nltk~=3.9
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
import tempfile
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_cert_file_path():
|
||||||
|
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||||
|
response = requests.get(cert_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
# Use a temporary file that is not deleted on close
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
return f.name
|
||||||
@@ -5,6 +5,7 @@ import requests
|
|||||||
|
|
||||||
from lib.test_utils import print_truncate_res
|
from lib.test_utils import print_truncate_res
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
"""
|
"""
|
||||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||||
@@ -51,6 +52,7 @@ def call_default_workflow(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
|
|||||||
+108
-87
@@ -6,6 +6,7 @@ from urllib.parse import urljoin
|
|||||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -19,6 +20,7 @@ COMPLETIONS_PROMPT = "the capital of USA is"
|
|||||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
||||||
|
|
||||||
|
|
||||||
class APIClient:
|
class APIClient:
|
||||||
"""Lightweight client focused solely on API communication"""
|
"""Lightweight client focused solely on API communication"""
|
||||||
|
|
||||||
@@ -26,21 +28,17 @@ class APIClient:
|
|||||||
DEFAULT_COST = 100
|
DEFAULT_COST = 100
|
||||||
DEFAULT_TIMEOUT = 4
|
DEFAULT_TIMEOUT = 4
|
||||||
|
|
||||||
def __init__(self, endpoint_group_name: str, api_key: str, server_url: str):
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint_group_name: str,
|
||||||
|
api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
endpoint_api_key: str,
|
||||||
|
):
|
||||||
self.endpoint_group_name = endpoint_group_name
|
self.endpoint_group_name = endpoint_group_name
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.endpoint_api_key = self._get_endpoint_api_key()
|
self.endpoint_api_key = endpoint_api_key
|
||||||
|
|
||||||
def _get_endpoint_api_key(self) -> Optional[str]:
|
|
||||||
"""Get the endpoint API key"""
|
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
|
||||||
endpoint_name=self.endpoint_group_name,
|
|
||||||
account_api_key=self.api_key,
|
|
||||||
)
|
|
||||||
if not endpoint_api_key:
|
|
||||||
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
|
|
||||||
return endpoint_api_key
|
|
||||||
|
|
||||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||||
"""Get worker URL and auth data from routing service"""
|
"""Get worker URL and auth data from routing service"""
|
||||||
@@ -71,21 +69,21 @@ class APIClient:
|
|||||||
"url": message["url"],
|
"url": message["url"],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST",
|
def _make_request(
|
||||||
stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]:
|
self,
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
endpoint: str,
|
||||||
|
method: str = "POST",
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||||
"""Make request directly to the specific worker endpoint"""
|
"""Make request directly to the specific worker endpoint"""
|
||||||
# Get worker URL and auth data
|
# Get worker URL and auth data
|
||||||
cost = payload.get('max_tokens')
|
cost = payload.get("max_tokens", self.DEFAULT_COST)
|
||||||
message = self._get_worker_url(cost=cost)
|
message = self._get_worker_url(cost=cost)
|
||||||
worker_url = message["url"]
|
worker_url = message["url"]
|
||||||
auth_data = self._create_auth_data(message)
|
auth_data = self._create_auth_data(message)
|
||||||
|
|
||||||
req_data = {
|
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
|
||||||
"payload": {
|
|
||||||
"input": payload
|
|
||||||
},
|
|
||||||
"auth_data": auth_data
|
|
||||||
}
|
|
||||||
|
|
||||||
url = urljoin(worker_url, endpoint)
|
url = urljoin(worker_url, endpoint)
|
||||||
log.debug(f"Making direct request to: {url}")
|
log.debug(f"Making direct request to: {url}")
|
||||||
@@ -93,9 +91,13 @@ class APIClient:
|
|||||||
|
|
||||||
# Make the request using the specified method
|
# Make the request using the specified method
|
||||||
if method.upper() == "POST":
|
if method.upper() == "POST":
|
||||||
response = requests.post(url, json=req_data, stream=stream)
|
response = requests.post(
|
||||||
|
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
elif method.upper() == "GET":
|
elif method.upper() == "GET":
|
||||||
response = requests.get(url, params=req_data, stream=stream)
|
response = requests.get(
|
||||||
|
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
|
|
||||||
@@ -124,23 +126,22 @@ class APIClient:
|
|||||||
log.error(f"Error handling streaming response: {e}")
|
log.error(f"Error handling streaming response: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def call_completions(
|
||||||
def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
|
self, config: CompletionConfig
|
||||||
|
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||||
payload = config.to_dict()
|
payload = config.to_dict()
|
||||||
|
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
payload=payload,
|
payload=payload, endpoint="/v1/completions", stream=config.stream
|
||||||
endpoint="/v1/completions",
|
|
||||||
stream=config.stream
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
|
def call_chat_completions(
|
||||||
|
self, config: ChatCompletionConfig
|
||||||
|
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||||
payload = config.to_dict()
|
payload = config.to_dict()
|
||||||
|
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
payload=payload,
|
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
|
||||||
endpoint="/v1/chat/completions",
|
|
||||||
stream=config.stream
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -151,7 +152,9 @@ class ToolManager:
|
|||||||
def list_files() -> str:
|
def list_files() -> str:
|
||||||
"""Execute ls on current directory"""
|
"""Execute ls on current directory"""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10)
|
result = subprocess.run(
|
||||||
|
["ls", "-la", "."], capture_output=True, text=True, timeout=10
|
||||||
|
)
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return result.stdout
|
return result.stdout
|
||||||
else:
|
else:
|
||||||
@@ -162,18 +165,16 @@ class ToolManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||||
"""Get the ls tool definition"""
|
"""Get the ls tool definition"""
|
||||||
return [{
|
return [
|
||||||
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "list_files",
|
"name": "list_files",
|
||||||
"description": "List files and directories in the cwd",
|
"description": "List files and directories in the cwd",
|
||||||
"parameters": {
|
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||||
"type": "object",
|
},
|
||||||
"properties": {},
|
|
||||||
"required": []
|
|
||||||
}
|
}
|
||||||
}
|
]
|
||||||
}]
|
|
||||||
|
|
||||||
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||||
"""Execute a tool call and return the result"""
|
"""Execute a tool call and return the result"""
|
||||||
@@ -188,12 +189,16 @@ class ToolManager:
|
|||||||
class APIDemo:
|
class APIDemo:
|
||||||
"""Demo and testing functionality for the API client"""
|
"""Demo and testing functionality for the API client"""
|
||||||
|
|
||||||
def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None):
|
def __init__(
|
||||||
|
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
|
||||||
|
):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tool_manager = tool_manager or ToolManager()
|
self.tool_manager = tool_manager or ToolManager()
|
||||||
|
|
||||||
def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str:
|
def handle_streaming_response(
|
||||||
|
self, response_stream, show_reasoning: bool = True
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Handle streaming chat response and display all output.
|
Handle streaming chat response and display all output.
|
||||||
"""
|
"""
|
||||||
@@ -260,27 +265,25 @@ class APIDemo:
|
|||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
|
|
||||||
def test_tool_support(self) -> bool:
|
def test_tool_support(self) -> bool:
|
||||||
"""Test if the endpoint supports function calling"""
|
"""Test if the endpoint supports function calling"""
|
||||||
log.debug("Testing endpoint tool calling support...")
|
log.debug("Testing endpoint tool calling support...")
|
||||||
|
|
||||||
# Try a simple request with minimal tools to test support
|
# Try a simple request with minimal tools to test support
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
minimal_tool = [{
|
minimal_tool = [
|
||||||
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {"name": "test_function", "description": "Test function"},
|
||||||
"name": "test_function",
|
|
||||||
"description": "Test function"
|
|
||||||
}
|
}
|
||||||
}]
|
]
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
config = ChatCompletionConfig(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
tools=minimal_tool,
|
tools=minimal_tool,
|
||||||
tool_choice="none" # Don't actually call the tool
|
tool_choice="none", # Don't actually call the tool
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -297,12 +300,12 @@ class APIDemo:
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
config = CompletionConfig(
|
config = CompletionConfig(
|
||||||
model=self.model,
|
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
|
||||||
prompt=COMPLETIONS_PROMPT,
|
|
||||||
stream=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'")
|
log.info(
|
||||||
|
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
|
||||||
|
)
|
||||||
response = self.client.call_completions(config)
|
response = self.client.call_completions(config)
|
||||||
|
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
@@ -316,7 +319,9 @@ class APIDemo:
|
|||||||
Demo: test chat completions endpoint with optional streaming
|
Demo: test chat completions endpoint with optional streaming
|
||||||
"""
|
"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
print(
|
||||||
|
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
|
||||||
|
)
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
config = ChatCompletionConfig(
|
||||||
@@ -334,6 +339,7 @@ class APIDemo:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"\nError during streaming: {e}")
|
log.error(f"\nError during streaming: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -342,7 +348,9 @@ class APIDemo:
|
|||||||
choice = response.get("choices", [{}])[0]
|
choice = response.get("choices", [{}])[0]
|
||||||
message = choice.get("message", {})
|
message = choice.get("message", {})
|
||||||
content = message.get("content", "")
|
content = message.get("content", "")
|
||||||
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
reasoning = message.get("reasoning_content", "") or message.get(
|
||||||
|
"reasoning", ""
|
||||||
|
)
|
||||||
|
|
||||||
if reasoning:
|
if reasoning:
|
||||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||||
@@ -353,8 +361,6 @@ class APIDemo:
|
|||||||
else:
|
else:
|
||||||
log.error("Unexpected response format")
|
log.error("Unexpected response format")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def demo_ls_tool(self) -> None:
|
def demo_ls_tool(self) -> None:
|
||||||
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
@@ -366,15 +372,13 @@ class APIDemo:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Request with tool available
|
# Request with tool available
|
||||||
messages = [
|
messages = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||||
{"role": "user", "content": TOOLS_PROMPT}
|
|
||||||
]
|
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
config = ChatCompletionConfig(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=self.tool_manager.get_ls_tool_definition(),
|
tools=self.tool_manager.get_ls_tool_definition(),
|
||||||
tool_choice="auto"
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"Making initial request with tool using model '{self.model}'...")
|
log.info(f"Making initial request with tool using model '{self.model}'...")
|
||||||
@@ -391,7 +395,9 @@ class APIDemo:
|
|||||||
# Check for tool calls
|
# Check for tool calls
|
||||||
tool_calls = message.get("tool_calls")
|
tool_calls = message.get("tool_calls")
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
raise ValueError("No tool calls made - model may not support function calling")
|
raise ValueError(
|
||||||
|
"No tool calls made - model may not support function calling"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Tool calls detected: {len(tool_calls)}")
|
print(f"Tool calls detected: {len(tool_calls)}")
|
||||||
|
|
||||||
@@ -405,17 +411,19 @@ class APIDemo:
|
|||||||
|
|
||||||
# Add tool result and continue conversation
|
# Add tool result and continue conversation
|
||||||
messages.append(message) # Add assistant's message with tool call
|
messages.append(message) # Add assistant's message with tool call
|
||||||
messages.append({
|
messages.append(
|
||||||
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tool_call["id"],
|
"tool_call_id": tool_call["id"],
|
||||||
"content": tool_result
|
"content": tool_result,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Get final response
|
# Get final response
|
||||||
final_config = ChatCompletionConfig(
|
final_config = ChatCompletionConfig(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=self.tool_manager.get_ls_tool_definition()
|
tools=self.tool_manager.get_ls_tool_definition(),
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Getting final response...")
|
print("Getting final response...")
|
||||||
@@ -447,10 +455,10 @@ class APIDemo:
|
|||||||
try:
|
try:
|
||||||
user_input = input("You: ").strip()
|
user_input = input("You: ").strip()
|
||||||
|
|
||||||
if user_input.lower() == 'quit':
|
if user_input.lower() == "quit":
|
||||||
print("👋 Goodbye!")
|
print("👋 Goodbye!")
|
||||||
break
|
break
|
||||||
elif user_input.lower() == 'clear':
|
elif user_input.lower() == "clear":
|
||||||
messages = []
|
messages = []
|
||||||
print("Chat history cleared")
|
print("Chat history cleared")
|
||||||
continue
|
continue
|
||||||
@@ -460,16 +468,15 @@ class APIDemo:
|
|||||||
messages.append({"role": "user", "content": user_input})
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
config = ChatCompletionConfig(
|
||||||
model=self.model,
|
model=self.model, messages=messages, stream=True, temperature=0.7
|
||||||
messages=messages,
|
|
||||||
stream=True,
|
|
||||||
temperature=0.7
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Assistant: ", end="", flush=True)
|
print("Assistant: ", end="", flush=True)
|
||||||
|
|
||||||
response = self.client.call_chat_completions(config)
|
response = self.client.call_chat_completions(config)
|
||||||
assistant_content = self.handle_streaming_response(response, show_reasoning=True)
|
assistant_content = self.handle_streaming_response(
|
||||||
|
response, show_reasoning=True
|
||||||
|
)
|
||||||
|
|
||||||
# Add assistant response to conversation history
|
# Add assistant response to conversation history
|
||||||
messages.append({"role": "assistant", "content": assistant_content})
|
messages.append({"role": "assistant", "content": assistant_content})
|
||||||
@@ -488,44 +495,43 @@ def main():
|
|||||||
|
|
||||||
# Add mandatory model argument
|
# Add mandatory model argument
|
||||||
test_args.add_argument(
|
test_args.add_argument(
|
||||||
"--model",
|
"--model", required=True, help="Model to use for requests (required)"
|
||||||
required=True,
|
|
||||||
help="Model to use for requests (required)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add test mode arguments
|
# Add test mode arguments
|
||||||
test_args.add_argument(
|
test_args.add_argument(
|
||||||
"--completion",
|
"--completion", action="store_true", help="Test completions endpoint"
|
||||||
action="store_true",
|
|
||||||
help="Test completions endpoint"
|
|
||||||
)
|
)
|
||||||
test_args.add_argument(
|
test_args.add_argument(
|
||||||
"--chat",
|
"--chat",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Test chat completions endpoint (non-streaming)"
|
help="Test chat completions endpoint (non-streaming)",
|
||||||
)
|
)
|
||||||
test_args.add_argument(
|
test_args.add_argument(
|
||||||
"--chat-stream",
|
"--chat-stream",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Test chat completions endpoint with streaming"
|
help="Test chat completions endpoint with streaming",
|
||||||
)
|
)
|
||||||
test_args.add_argument(
|
test_args.add_argument(
|
||||||
"--tools",
|
"--tools",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Test function calling with ls tool (non-streaming)"
|
help="Test function calling with ls tool (non-streaming)",
|
||||||
)
|
)
|
||||||
test_args.add_argument(
|
test_args.add_argument(
|
||||||
"--interactive",
|
"--interactive",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Start interactive streaming chat session"
|
help="Start interactive streaming chat session",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = test_args.parse_args()
|
args = test_args.parse_args()
|
||||||
|
|
||||||
# Check that only one test mode is selected
|
# Check that only one test mode is selected
|
||||||
test_modes = [
|
test_modes = [
|
||||||
args.completion, args.chat, args.chat_stream,
|
args.completion,
|
||||||
args.tools, args.interactive
|
args.chat,
|
||||||
|
args.chat_stream,
|
||||||
|
args.tools,
|
||||||
|
args.interactive,
|
||||||
]
|
]
|
||||||
selected_count = sum(test_modes)
|
selected_count = sum(test_modes)
|
||||||
|
|
||||||
@@ -536,18 +542,33 @@ def main():
|
|||||||
print(" --chat-stream : Test chat completions endpoint with streaming")
|
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||||
print(" --tools : Test function calling with ls tool (non-streaming)")
|
print(" --tools : Test function calling with ls tool (non-streaming)")
|
||||||
print(" --interactive : Start interactive streaming chat session")
|
print(" --interactive : Start interactive streaming chat session")
|
||||||
print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT")
|
print(
|
||||||
|
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
elif selected_count > 1:
|
elif selected_count > 1:
|
||||||
print("Please specify exactly one test mode")
|
print("Please specify exactly one test mode")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
instance=args.instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not endpoint_api_key:
|
||||||
|
log.error(
|
||||||
|
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# Create the core API client
|
# Create the core API client
|
||||||
client = APIClient(
|
client = APIClient(
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
api_key=args.api_key,
|
api_key=args.api_key,
|
||||||
server_url=args.server_url
|
server_url=args.server_url,
|
||||||
|
endpoint_api_key=endpoint_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create tool manager and demo (passing the model parameter)
|
# Create tool manager and demo (passing the model parameter)
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ from typing import Optional, List, Dict, Any
|
|||||||
class SerializableDataclass:
|
class SerializableDataclass:
|
||||||
def _serialize_recursive(self, obj: Any) -> Any:
|
def _serialize_recursive(self, obj: Any) -> Any:
|
||||||
if is_dataclass(obj):
|
if is_dataclass(obj):
|
||||||
return {field.name: self._serialize_recursive(getattr(obj, field.name))
|
return {
|
||||||
for field in fields(obj)}
|
field.name: self._serialize_recursive(getattr(obj, field.name))
|
||||||
|
for field in fields(obj)
|
||||||
|
}
|
||||||
elif isinstance(obj, dict):
|
elif isinstance(obj, dict):
|
||||||
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
||||||
elif isinstance(obj, (list, tuple)):
|
elif isinstance(obj, (list, tuple)):
|
||||||
@@ -27,6 +29,7 @@ class SerializableDataclass:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class CompletionConfig(SerializableDataclass):
|
class CompletionConfig(SerializableDataclass):
|
||||||
"""Configuration for completion requests"""
|
"""Configuration for completion requests"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
prompt: str = "Hello"
|
prompt: str = "Hello"
|
||||||
max_tokens: int = 256
|
max_tokens: int = 256
|
||||||
@@ -39,8 +42,9 @@ class CompletionConfig(SerializableDataclass):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ChatCompletionConfig(SerializableDataclass):
|
class ChatCompletionConfig(SerializableDataclass):
|
||||||
"""Configuration for chat completion requests"""
|
"""Configuration for chat completion requests"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
messages: list = None
|
messages: list = field(default_factory=list)
|
||||||
max_tokens: int = 2096
|
max_tokens: int = 2096
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
top_k: int = 20
|
top_k: int = 20
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os, json, random
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
||||||
from typing import Union, Type, Dict, Any
|
from typing import Union, Type, Dict, Any, Optional
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import nltk
|
import nltk
|
||||||
import logging
|
import logging
|
||||||
@@ -14,15 +14,15 @@ log = logging.getLogger(__name__)
|
|||||||
"""
|
"""
|
||||||
Generic dataclass accepts any dictionary in input.
|
Generic dataclass accepts any dictionary in input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenericData(ApiPayload, ABC):
|
class GenericData(ApiPayload, ABC):
|
||||||
input: Dict[str, Any]
|
input: Dict[str, Any]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
||||||
return cls(
|
return cls(input=data["input"])
|
||||||
input=data["input"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
||||||
@@ -39,9 +39,7 @@ class GenericData(ApiPayload, ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Create clean data dict and delegate to from_dict
|
# Create clean data dict and delegate to from_dict
|
||||||
clean_data = {
|
clean_data = {"input": json_msg["input"]}
|
||||||
"input": json_msg["input"]
|
|
||||||
}
|
|
||||||
|
|
||||||
return cls.from_dict(clean_data)
|
return cls.from_dict(clean_data)
|
||||||
|
|
||||||
@@ -60,6 +58,7 @@ class GenericData(ApiPayload, ABC):
|
|||||||
def count_workload(self) -> int:
|
def count_workload(self) -> int:
|
||||||
return self.input.get("max_tokens", 0)
|
return self.input.get("max_tokens", 0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenericHandler(EndpointHandler[GenericData], ABC):
|
class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||||
|
|
||||||
@@ -69,8 +68,8 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def healthcheck_endpoint(self) -> str:
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
return os.environ.get('MODEL_HEALTH_ENDPOINT')
|
return os.environ.get("MODEL_HEALTH_ENDPOINT")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[GenericData]:
|
def payload_cls(cls) -> Type[GenericData]:
|
||||||
@@ -87,10 +86,10 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
case 200:
|
case 200:
|
||||||
# Check if the response is actually streaming based on response headers/content-type
|
# Check if the response is actually streaming based on response headers/content-type
|
||||||
is_streaming_response = (
|
is_streaming_response = (
|
||||||
model_response.content_type == "text/event-stream" or
|
model_response.content_type == "text/event-stream"
|
||||||
model_response.content_type == "application/x-ndjson" or
|
or model_response.content_type == "application/x-ndjson"
|
||||||
model_response.headers.get("Transfer-Encoding") == "chunked" or
|
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||||
"stream" in model_response.content_type.lower()
|
or "stream" in model_response.content_type.lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_streaming_response:
|
if is_streaming_response:
|
||||||
@@ -109,12 +108,13 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
return web.Response(
|
return web.Response(
|
||||||
body=content,
|
body=content,
|
||||||
status=200,
|
status=200,
|
||||||
content_type=model_response.content_type
|
content_type=model_response.content_type,
|
||||||
)
|
)
|
||||||
case code:
|
case code:
|
||||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
return web.Response(status=code)
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompletionsData(GenericData):
|
class CompletionsData(GenericData):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -127,10 +127,12 @@ class CompletionsData(GenericData):
|
|||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"temperature": 0.7
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompletionsHandler(GenericHandler):
|
class CompletionsHandler(GenericHandler):
|
||||||
@property
|
@property
|
||||||
@@ -144,6 +146,7 @@ class CompletionsHandler(GenericHandler):
|
|||||||
def make_benchmark_payload(self) -> CompletionsData:
|
def make_benchmark_payload(self) -> CompletionsData:
|
||||||
return CompletionsData.for_test()
|
return CompletionsData.for_test()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatCompletionsData(GenericData):
|
class ChatCompletionsData(GenericData):
|
||||||
"""Chat completions-specific data implementation"""
|
"""Chat completions-specific data implementation"""
|
||||||
@@ -159,10 +162,12 @@ class ChatCompletionsData(GenericData):
|
|||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"temperature": 0.7
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatCompletionsHandler(GenericHandler):
|
class ChatCompletionsHandler(GenericHandler):
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ logging.basicConfig(
|
|||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
backend = Backend(
|
backend = Backend(
|
||||||
model_server_url=os.environ.get("MODEL_SERVER_URL"),
|
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||||
model_log_file=os.environ.get("MODEL_LOG"),
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
log_actions=[
|
log_actions=[
|
||||||
@@ -45,9 +45,11 @@ backend = Backend(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_ping(_):
|
async def handle_ping(_):
|
||||||
return web.Response(body="pong")
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
routes = [
|
routes = [
|
||||||
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
||||||
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
||||||
|
|||||||
@@ -13,14 +13,14 @@ if __name__ == "__main__":
|
|||||||
"--model",
|
"--model",
|
||||||
dest="model",
|
dest="model",
|
||||||
required=not model_name_set,
|
required=not model_name_set,
|
||||||
help="Model to use for completions request (required if MODEL_NAME env var not set)"
|
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse known args to get model early, before test_load_cmd adds its args
|
# Parse known args to get model early, before test_load_cmd adds its args
|
||||||
known_args, _ = test_args.parse_known_args()
|
known_args, _ = test_args.parse_known_args()
|
||||||
|
|
||||||
# Set environment variable if model was provided
|
# Set environment variable if model was provided
|
||||||
if hasattr(known_args, 'model') and known_args.model:
|
if hasattr(known_args, "model") and known_args.model:
|
||||||
os.environ["MODEL_NAME"] = known_args.model
|
os.environ["MODEL_NAME"] = known_args.model
|
||||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
|||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
print(f"url: {url}")
|
print(f"url: {url}")
|
||||||
response = requests.post(url, json=req_data)
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
res = response.json()
|
res = response.json()
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
Reference in New Issue
Block a user