Compare commits

...

3 Commits

Author SHA1 Message Date
Nader Arbabian d3be9fe7db redo metrics tracking for requests, fixes bug wherere some requests were marked as pending, even though they had finished 2025-07-30 18:56:51 -07:00
Rob Ballantyne e0be45f39a Addresses breaking change in core pyworker (#22)
* Addresses breaking change in test_utils.py

Endpoint.get_endpoint_api_key() now requires instance

Moves the call to this function out of the APIClient and into main

* Ensure make_benchmark_payload has a value to calculate the workload

---------

Co-authored-by: Nader Arbabian <nader@vast.ai>
2025-07-18 16:11:10 -07:00
Nader Arbabian be2aafdb1f fix pyright errors + revert to old way of handling cancelled api requests (#23) 2025-07-17 16:59:06 -07:00
10 changed files with 306 additions and 273 deletions
+22 -17
View File
@@ -5,7 +5,7 @@ import base64
import subprocess
import dataclasses
import logging
from asyncio import sleep, gather, Semaphore
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
@@ -123,6 +123,12 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
@@ -135,7 +141,6 @@ class Backend:
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try:
start_time = time.time()
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
@@ -147,19 +152,17 @@ class Backend:
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_end(
workload=workload,
req_response_time=time.time() - start_time,
reqnum=auth_data.reqnum,
)
self.metrics._request_success(workload=workload)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(
workload=workload, reqnum=auth_data.reqnum
)
self.metrics._request_errored(workload=workload)
return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
###########
@@ -168,16 +171,18 @@ class Backend:
return web.Response(status=401)
try:
return await make_request()
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
if request.task.cancelled():
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(
workload=workload, reqnum=auth_data.reqnum
)
async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint
+7 -4
View File
@@ -8,7 +8,6 @@ from aiohttp import web, ClientResponse
import inspect
import psutil
import requests
"""
@@ -206,13 +205,13 @@ class ModelMetrics:
workload_received: float
workload_cancelled: float
workload_errored: float
workload_pending: float
# these are not
cur_perf: float
workload_pending: float
error_msg: Optional[str]
max_throughput: float
requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set)
last_update: float = field(default_factory=time.time)
@classmethod
def empty(cls):
@@ -221,12 +220,15 @@ class ModelMetrics:
workload_served=0.0,
workload_cancelled=0.0,
workload_errored=0.0,
cur_perf=0.0,
workload_received=0.0,
error_msg=None,
max_throughput=0.0,
)
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property
def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0)
@@ -240,6 +242,7 @@ class ModelMetrics:
self.workload_received = 0
self.workload_cancelled = 0
self.workload_errored = 0
self.last_update = time.time()
@dataclass
+10 -12
View File
@@ -46,33 +46,31 @@ class Metrics:
self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum)
def _request_end(
self, workload: float, req_response_time: float, reqnum: int
) -> None:
def _request_end(self, workload: float, reqnum: int) -> None:
"""
this function is called after a response from model API is received.
this function is called after handling of a request ends, regardless of the outcome
"""
self.model_metrics.workload_served += workload
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.cur_perf = workload / req_response_time
def _request_success(self, workload: float) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += workload
self.update_pending = True
def _request_errored(self, workload: float, reqnum: int) -> None:
def _request_errored(self, workload: float) -> None:
"""
this function is called if model API returns an error
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_errored += workload
self.model_metrics.requests_working.discard(reqnum)
def _request_canceled(self, workload: float, reqnum: int) -> None:
def _request_canceled(self, workload: float) -> None:
"""
this function is called if client drops connection before model API has responded
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_cancelled += workload
self.model_metrics.requests_working.discard(reqnum)
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True:
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app, handler_cancellation=True)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
+2 -2
View File
@@ -1,4 +1,4 @@
aiohttp~=3.11
aiohttp[speedups]==3.10.1
anyio~=4.4
lib~=4.0
nltk~=3.9
@@ -6,5 +6,5 @@ psutil~=6.0
pycryptodome~=3.20
Requests~=2.32
transformers~=4.52
utils~=1.0
utils==1.0.*
hf_transfer>=0.1.9
+101 -85
View File
@@ -19,6 +19,7 @@ COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient:
"""Lightweight client focused solely on API communication"""
@@ -26,21 +27,17 @@ class APIClient:
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(self, endpoint_group_name: str, api_key: str, server_url: str):
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = self._get_endpoint_api_key()
def _get_endpoint_api_key(self) -> Optional[str]:
"""Get the endpoint API key"""
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
return endpoint_api_key
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
@@ -71,21 +68,21 @@ class APIClient:
"url": message["url"],
}
def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST",
stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]:
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get('max_tokens')
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {
"payload": {
"input": payload
},
"auth_data": auth_data
}
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
@@ -124,23 +121,22 @@ class APIClient:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload,
endpoint="/v1/completions",
stream=config.stream
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload,
endpoint="/v1/chat/completions",
stream=config.stream
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
@@ -151,7 +147,9 @@ class ToolManager:
def list_files() -> str:
"""Execute ls on current directory"""
try:
result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10)
result = subprocess.run(
["ls", "-la", "."], capture_output=True, text=True, timeout=10
)
if result.returncode == 0:
return result.stdout
else:
@@ -162,18 +160,16 @@ class ToolManager:
@staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition"""
return [{
return [
{
"type": "function",
"function": {
"name": "list_files",
"description": "List files and directories in the cwd",
"parameters": {
"type": "object",
"properties": {},
"required": []
"parameters": {"type": "object", "properties": {}, "required": []},
},
}
}
}]
]
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result"""
@@ -188,12 +184,16 @@ class ToolManager:
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None):
def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client
self.model = model
self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str:
def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
@@ -260,27 +260,25 @@ class APIDemo:
return full_response
def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...")
# Try a simple request with minimal tools to test support
messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [{
minimal_tool = [
{
"type": "function",
"function": {
"name": "test_function",
"description": "Test function"
"function": {"name": "test_function", "description": "Test function"},
}
}]
]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none" # Don't actually call the tool
tool_choice="none", # Don't actually call the tool
)
try:
@@ -297,12 +295,12 @@ class APIDemo:
print("=" * 60)
config = CompletionConfig(
model=self.model,
prompt=COMPLETIONS_PROMPT,
stream=False
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'")
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
@@ -316,7 +314,9 @@ class APIDemo:
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
@@ -334,6 +334,7 @@ class APIDemo:
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
@@ -342,7 +343,9 @@ class APIDemo:
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
@@ -353,8 +356,6 @@ class APIDemo:
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60)
@@ -366,15 +367,13 @@ class APIDemo:
return
# Request with tool available
messages = [
{"role": "user", "content": TOOLS_PROMPT}
]
messages = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto"
tool_choice="auto",
)
log.info(f"Making initial request with tool using model '{self.model}'...")
@@ -391,7 +390,9 @@ class APIDemo:
# Check for tool calls
tool_calls = message.get("tool_calls")
if not tool_calls:
raise ValueError("No tool calls made - model may not support function calling")
raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}")
@@ -405,17 +406,19 @@ class APIDemo:
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append({
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result
})
"content": tool_result,
}
)
# Get final response
final_config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition()
tools=self.tool_manager.get_ls_tool_definition(),
)
print("Getting final response...")
@@ -447,10 +450,10 @@ class APIDemo:
try:
user_input = input("You: ").strip()
if user_input.lower() == 'quit':
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif user_input.lower() == 'clear':
elif user_input.lower() == "clear":
messages = []
print("Chat history cleared")
continue
@@ -460,16 +463,15 @@ class APIDemo:
messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model,
messages=messages,
stream=True,
temperature=0.7
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(response, show_reasoning=True)
assistant_content = self.handle_streaming_response(
response, show_reasoning=True
)
# Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content})
@@ -488,44 +490,43 @@ def main():
# Add mandatory model argument
test_args.add_argument(
"--model",
required=True,
help="Model to use for requests (required)"
"--model", required=True, help="Model to use for requests (required)"
)
# Add test mode arguments
test_args.add_argument(
"--completion",
action="store_true",
help="Test completions endpoint"
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)"
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming"
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)"
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session"
help="Start interactive streaming chat session",
)
args = test_args.parse_args()
# Check that only one test mode is selected
test_modes = [
args.completion, args.chat, args.chat_stream,
args.tools, args.interactive
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
@@ -536,18 +537,33 @@ def main():
print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session")
print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT")
print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1)
elif selected_count > 1:
print("Please specify exactly one test mode")
sys.exit(1)
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=args.server_url
server_url=args.server_url,
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
+7 -3
View File
@@ -6,8 +6,10 @@ from typing import Optional, List, Dict, Any
class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj):
return {field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)}
return {
field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
@@ -27,6 +29,7 @@ class SerializableDataclass:
@dataclass
class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests"""
model: str
prompt: str = "Hello"
max_tokens: int = 256
@@ -39,8 +42,9 @@ class CompletionConfig(SerializableDataclass):
@dataclass
class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests"""
model: str
messages: list = None
messages: list = field(default_factory=list)
max_tokens: int = 2096
temperature: float = 0.7
top_k: int = 20
+21 -16
View File
@@ -2,7 +2,7 @@ import os, json, random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any
from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse
import nltk
import logging
@@ -14,15 +14,15 @@ log = logging.getLogger(__name__)
"""
Generic dataclass accepts any dictionary in input.
"""
@dataclass
class GenericData(ApiPayload, ABC):
input: Dict[str, Any]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls(
input=data["input"]
)
return cls(input=data["input"])
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
@@ -39,9 +39,7 @@ class GenericData(ApiPayload, ABC):
try:
# Create clean data dict and delegate to from_dict
clean_data = {
"input": json_msg["input"]
}
clean_data = {"input": json_msg["input"]}
return cls.from_dict(clean_data)
@@ -60,6 +58,7 @@ class GenericData(ApiPayload, ABC):
def count_workload(self) -> int:
return self.input.get("max_tokens", 0)
@dataclass
class GenericHandler(EndpointHandler[GenericData], ABC):
@@ -69,8 +68,8 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
pass
@property
def healthcheck_endpoint(self) -> str:
return os.environ.get('MODEL_HEALTH_ENDPOINT')
def healthcheck_endpoint(self) -> Optional[str]:
return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod
def payload_cls(cls) -> Type[GenericData]:
@@ -87,10 +86,10 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
case 200:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream" or
model_response.content_type == "application/x-ndjson" or
model_response.headers.get("Transfer-Encoding") == "chunked" or
"stream" in model_response.content_type.lower()
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
@@ -109,12 +108,13 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
return web.Response(
body=content,
status=200,
content_type=model_response.content_type
content_type=model_response.content_type,
)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclass
class CompletionsData(GenericData):
@classmethod
@@ -127,10 +127,12 @@ class CompletionsData(GenericData):
test_input = {
"model": model,
"prompt": prompt,
"temperature": 0.7
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class CompletionsHandler(GenericHandler):
@property
@@ -144,6 +146,7 @@ class CompletionsHandler(GenericHandler):
def make_benchmark_payload(self) -> CompletionsData:
return CompletionsData.for_test()
@dataclass
class ChatCompletionsData(GenericData):
"""Chat completions-specific data implementation"""
@@ -159,10 +162,12 @@ class ChatCompletionsData(GenericData):
test_input = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class ChatCompletionsHandler(GenericHandler):
@property
+5 -3
View File
@@ -20,7 +20,7 @@ MODEL_SERVER_ERROR_LOG_MSGS = [
"stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI
"Error: ShardCannotStart", #TGI
"Error: ShardCannotStart", # TGI
]
logging.basicConfig(
@@ -31,8 +31,8 @@ logging.basicConfig(
log = logging.getLogger(__file__)
backend = Backend(
model_server_url=os.environ.get("MODEL_SERVER_URL"),
model_log_file=os.environ.get("MODEL_LOG"),
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
@@ -45,9 +45,11 @@ backend = Backend(
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
+2 -2
View File
@@ -13,14 +13,14 @@ if __name__ == "__main__":
"--model",
dest="model",
required=not model_name_set,
help="Model to use for completions request (required if MODEL_NAME env var not set)"
help="Model to use for completions request (required if MODEL_NAME env var not set)",
)
# Parse known args to get model early, before test_load_cmd adds its args
known_args, _ = test_args.parse_known_args()
# Set environment variable if model was provided
if hasattr(known_args, 'model') and known_args.model:
if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}")