Addresses breaking change in core pyworker (#22)
* Addresses breaking change in test_utils.py: Endpoint.get_endpoint_api_key() now requires `instance`. Moves the call to this function out of the APIClient and into main.
* Ensures make_benchmark_payload has a value to calculate the workload.
---------
Co-authored-by: Nader Arbabian <nader@vast.ai>
This commit is contained in:
committed by
Nader Arbabian
parent
be2aafdb1f
commit
e0be45f39a
+1
-1
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
|||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app, handler_cancellation=True)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
|
|||||||
+19
-15
@@ -28,24 +28,16 @@ class APIClient:
|
|||||||
DEFAULT_TIMEOUT = 4
|
DEFAULT_TIMEOUT = 4
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str
|
self,
|
||||||
|
endpoint_group_name: str,
|
||||||
|
api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
endpoint_api_key: str,
|
||||||
):
|
):
|
||||||
self.endpoint_group_name = endpoint_group_name
|
self.endpoint_group_name = endpoint_group_name
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.instance = instance
|
self.endpoint_api_key = endpoint_api_key
|
||||||
self.endpoint_api_key = self._get_endpoint_api_key()
|
|
||||||
|
|
||||||
def _get_endpoint_api_key(self) -> Optional[str]:
|
|
||||||
"""Get the endpoint API key"""
|
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
|
||||||
endpoint_name=self.endpoint_group_name,
|
|
||||||
account_api_key=self.api_key,
|
|
||||||
instance=self.instance,
|
|
||||||
)
|
|
||||||
if not endpoint_api_key:
|
|
||||||
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
|
|
||||||
return endpoint_api_key
|
|
||||||
|
|
||||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||||
"""Get worker URL and auth data from routing service"""
|
"""Get worker URL and auth data from routing service"""
|
||||||
@@ -554,12 +546,24 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
instance=args.instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not endpoint_api_key:
|
||||||
|
log.error(
|
||||||
|
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# Create the core API client
|
# Create the core API client
|
||||||
client = APIClient(
|
client = APIClient(
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
api_key=args.api_key,
|
api_key=args.api_key,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
instance=args.instance,
|
endpoint_api_key=endpoint_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create tool manager and demo (passing the model parameter)
|
# Create tool manager and demo (passing the model parameter)
|
||||||
|
|||||||
@@ -124,7 +124,12 @@ class CompletionsData(GenericData):
|
|||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
test_input = {"model": model, "prompt": prompt, "temperature": 0.7}
|
test_input = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
|
}
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
@@ -158,6 +163,7 @@ class ChatCompletionsData(GenericData):
|
|||||||
"model": model,
|
"model": model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user