Addresses breaking change in core pyworker (#22)
* Addresses breaking change in test_utils.py: Endpoint.get_endpoint_api_key() now requires instance. Moves the call to this function out of the APIClient and into main.
* Ensure make_benchmark_payload has a value to calculate the workload.

---------

Co-authored-by: Nader Arbabian <nader@vast.ai>
This commit is contained in:
committed by
Nader Arbabian
parent
be2aafdb1f
commit
e0be45f39a
+1
-1
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
log.debug("starting server...")
|
||||
app = web.Application()
|
||||
app.add_routes(routes)
|
||||
runner = web.AppRunner(app, handler_cancellation=True)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(
|
||||
runner,
|
||||
|
||||
+19
-15
@@ -28,24 +28,16 @@ class APIClient:
|
||||
DEFAULT_TIMEOUT = 4
|
||||
|
||||
def __init__(
|
||||
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str
|
||||
self,
|
||||
endpoint_group_name: str,
|
||||
api_key: str,
|
||||
server_url: str,
|
||||
endpoint_api_key: str,
|
||||
):
|
||||
self.endpoint_group_name = endpoint_group_name
|
||||
self.api_key = api_key
|
||||
self.server_url = server_url
|
||||
self.instance = instance
|
||||
self.endpoint_api_key = self._get_endpoint_api_key()
|
||||
|
||||
def _get_endpoint_api_key(self) -> Optional[str]:
|
||||
"""Get the endpoint API key"""
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=self.endpoint_group_name,
|
||||
account_api_key=self.api_key,
|
||||
instance=self.instance,
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
|
||||
return endpoint_api_key
|
||||
self.endpoint_api_key = endpoint_api_key
|
||||
|
||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||
"""Get worker URL and auth data from routing service"""
|
||||
@@ -554,12 +546,24 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
if not endpoint_api_key:
|
||||
log.error(
|
||||
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Create the core API client
|
||||
client = APIClient(
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
api_key=args.api_key,
|
||||
server_url=args.server_url,
|
||||
instance=args.instance,
|
||||
endpoint_api_key=endpoint_api_key,
|
||||
)
|
||||
|
||||
# Create tool manager and demo (passing the model parameter)
|
||||
|
||||
@@ -124,7 +124,12 @@ class CompletionsData(GenericData):
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
test_input = {"model": model, "prompt": prompt, "temperature": 0.7}
|
||||
test_input = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@@ -158,6 +163,7 @@ class ChatCompletionsData(GenericData):
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user