fix completions and interactive client
This commit is contained in:
+20
-7
@@ -16,6 +16,24 @@ class Endpoint:
|
|||||||
Utility class for handling endpoint operations.
|
Utility class for handling endpoint operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_autoscaler_server_url(instance: str) -> str:
|
||||||
|
endpoints = {
|
||||||
|
"alpha": "run-alpha",
|
||||||
|
"candidate": "run-candidate",
|
||||||
|
"prod": "run",
|
||||||
|
}
|
||||||
|
return f"https://{endpoints[instance]}.vast.ai/"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_server_url(instance: str) -> str:
|
||||||
|
endpoints = {
|
||||||
|
"alpha": "alpha",
|
||||||
|
"candidate": "candidate",
|
||||||
|
"prod": "console",
|
||||||
|
}
|
||||||
|
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(
|
def get_endpoint_api_key(
|
||||||
endpoint_name: str, account_api_key: str, instance: str
|
endpoint_name: str, account_api_key: str, instance: str
|
||||||
@@ -30,18 +48,13 @@ class Endpoint:
|
|||||||
Returns:
|
Returns:
|
||||||
Endpoint API key if successful, None otherwise
|
Endpoint API key if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
endpoints = {
|
|
||||||
"alpha": "alpha",
|
|
||||||
"candidate": "candidate",
|
|
||||||
"prod": "console",
|
|
||||||
}
|
|
||||||
vast_console_url = f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
|
||||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
|
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
|||||||
@@ -567,7 +567,7 @@ def main():
|
|||||||
client = APIClient(
|
client = APIClient(
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
api_key=args.api_key,
|
api_key=args.api_key,
|
||||||
server_url=args.server_url,
|
server_url=Endpoint.get_autoscaler_server_url(args.instance),
|
||||||
endpoint_api_key=endpoint_api_key,
|
endpoint_api_key=endpoint_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user