Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eecefd1d52 | |||
| 52ac4c0c1a | |||
| 8804e17201 | |||
| 4016cf9a53 | |||
| e0be45f39a | |||
| be2aafdb1f | |||
| 9e369c55a5 | |||
| 69d9b7455f | |||
| 6fb610cb5b | |||
| 0bf2d04223 | |||
| 9ebf1924ea | |||
| 0ab9a13a46 |
+75
-59
@@ -8,6 +8,7 @@ import logging
|
|||||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from distutils.util import strtobool
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
||||||
@@ -55,11 +56,15 @@ class Backend:
|
|||||||
reqnum = -1
|
reqnum = -1
|
||||||
msg_history = []
|
msg_history = []
|
||||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
|
unsecured: bool = dataclasses.field(
|
||||||
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
|
self.__start_healthcheck: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||||
@@ -118,14 +123,10 @@ class Backend:
|
|||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
|
|
||||||
async def wait_for_disconnection() -> None:
|
|
||||||
while request.transport and not request.transport.is_closing():
|
|
||||||
await sleep(0.5)
|
|
||||||
|
|
||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
await wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
self.metrics._request_canceled(workload=workload)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
@@ -140,7 +141,6 @@ class Backend:
|
|||||||
else:
|
else:
|
||||||
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
|
||||||
response = await self.__call_api(handler=handler, payload=payload)
|
response = await self.__call_api(handler=handler, payload=payload)
|
||||||
status_code = response.status
|
status_code = response.status
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -152,19 +152,17 @@ class Backend:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
res = await handler.generate_client_response(request, response)
|
res = await handler.generate_client_response(request, response)
|
||||||
self.metrics._request_end(
|
self.metrics._request_success(workload=workload)
|
||||||
workload=workload,
|
|
||||||
req_response_time=time.time() - start_time,
|
|
||||||
reqnum=auth_data.reqnum,
|
|
||||||
)
|
|
||||||
return res
|
return res
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
log.debug(f"[backend] Request error: {e}")
|
log.debug(f"[backend] Request error: {e}")
|
||||||
self.metrics._request_errored(
|
self.metrics._request_errored(workload=workload)
|
||||||
workload=workload, reqnum=auth_data.reqnum
|
|
||||||
)
|
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
finally:
|
||||||
|
self.metrics._request_end(
|
||||||
|
workload=workload,
|
||||||
|
reqnum=auth_data.reqnum,
|
||||||
|
)
|
||||||
self.sem.release()
|
self.sem.release()
|
||||||
|
|
||||||
###########
|
###########
|
||||||
@@ -191,23 +189,26 @@ class Backend:
|
|||||||
if health_check_url is None:
|
if health_check_url is None:
|
||||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
return
|
return
|
||||||
await sleep(5)
|
while True:
|
||||||
try:
|
await sleep(10)
|
||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
if self.__start_healthcheck is False:
|
||||||
async with self.session.get(health_check_url) as response:
|
continue
|
||||||
if response.status == 200:
|
try:
|
||||||
log.debug("Healthcheck successful")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
elif response.status == 503:
|
async with self.session.get(health_check_url) as response:
|
||||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
if response.status == 200:
|
||||||
self.backend_errored(
|
log.debug("Healthcheck successful")
|
||||||
f"Healthcheck failed with status: {response.status}"
|
elif response.status == 503:
|
||||||
)
|
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||||
else:
|
self.backend_errored(
|
||||||
# endpoint not ready yet so bail
|
f"Healthcheck failed with status: {response.status}"
|
||||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
)
|
||||||
except Exception as e:
|
else:
|
||||||
log.debug(f"Healthcheck failed with exception: {e}")
|
# endpoint not ready yet so bail
|
||||||
self.backend_errored(str(e))
|
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Healthcheck failed with exception: {e}")
|
||||||
|
self.backend_errored(str(e))
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
await gather(
|
await gather(
|
||||||
@@ -225,6 +226,9 @@ class Backend:
|
|||||||
return await self.session.post(url=handler.endpoint, json=api_payload)
|
return await self.session.post(url=handler.endpoint, json=api_payload)
|
||||||
|
|
||||||
def __check_signature(self, auth_data: AuthData) -> bool:
|
def __check_signature(self, auth_data: AuthData) -> bool:
|
||||||
|
if self.unsecured is True:
|
||||||
|
return True
|
||||||
|
|
||||||
def verify_signature(message, signature):
|
def verify_signature(message, signature):
|
||||||
if self.pubkey is None:
|
if self.pubkey is None:
|
||||||
log.debug(f"No Public Key!")
|
log.debug(f"No Public Key!")
|
||||||
@@ -276,41 +280,52 @@ class Backend:
|
|||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
log.debug("Initial run to trigger model loading...")
|
||||||
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
|
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
|
|
||||||
max_throughput = 0
|
max_throughput = 0
|
||||||
last_throughput = 0
|
|
||||||
sum_throughput = 0
|
sum_throughput = 0
|
||||||
for run in range(self.benchmark_handler.benchmark_runs + 1):
|
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
||||||
|
|
||||||
|
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
tasks = []
|
||||||
res = await self.__call_api(
|
total_workload = 0
|
||||||
handler=self.benchmark_handler, payload=payload
|
|
||||||
)
|
for _ in range(concurrent_requests):
|
||||||
data = await res.json()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
time_elapsed = time.time() - start
|
total_workload += payload.count_workload()
|
||||||
# first run triggers one-time loading of the model which is very slow, so we skip counting it
|
tasks.append(
|
||||||
if run == 0:
|
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
continue
|
|
||||||
else:
|
|
||||||
workload = payload.count_workload()
|
|
||||||
last_throughput = workload / time_elapsed
|
|
||||||
sum_throughput += last_throughput
|
|
||||||
max_throughput = max(max_throughput, last_throughput)
|
|
||||||
log.debug(
|
|
||||||
"\n".join(
|
|
||||||
[
|
|
||||||
"#" * 60,
|
|
||||||
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
|
|
||||||
"",
|
|
||||||
f"response: {data}",
|
|
||||||
"#" * 60,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
responses = await gather(*tasks)
|
||||||
|
time_elapsed = time.time() - start
|
||||||
|
|
||||||
|
throughput = total_workload / time_elapsed
|
||||||
|
sum_throughput += throughput
|
||||||
|
max_throughput = max(max_throughput, throughput)
|
||||||
|
|
||||||
|
# Log results for debugging
|
||||||
|
log.debug(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"#" * 60,
|
||||||
|
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||||
|
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||||
|
f"Throughput: {throughput} workload/s",
|
||||||
|
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
||||||
|
"#" * 60,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||||
log.debug(
|
log.debug(
|
||||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||||
)
|
)
|
||||||
# save max_throughput so we don't have to run benchmark again on restart of cold instances
|
|
||||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||||
f.write(str(max_throughput))
|
f.write(str(max_throughput))
|
||||||
return max_throughput
|
return max_throughput
|
||||||
@@ -331,6 +346,7 @@ class Backend:
|
|||||||
await sleep(5)
|
await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
|
self.__start_healthcheck = True
|
||||||
self.metrics._model_loaded(
|
self.metrics._model_loaded(
|
||||||
max_throughput=max_throughput,
|
max_throughput=max_throughput,
|
||||||
)
|
)
|
||||||
|
|||||||
+7
-4
@@ -8,7 +8,6 @@ from aiohttp import web, ClientResponse
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -206,13 +205,13 @@ class ModelMetrics:
|
|||||||
workload_received: float
|
workload_received: float
|
||||||
workload_cancelled: float
|
workload_cancelled: float
|
||||||
workload_errored: float
|
workload_errored: float
|
||||||
workload_pending: float
|
|
||||||
# these are not
|
# these are not
|
||||||
cur_perf: float
|
workload_pending: float
|
||||||
error_msg: Optional[str]
|
error_msg: Optional[str]
|
||||||
max_throughput: float
|
max_throughput: float
|
||||||
requests_recieved: Set[int] = field(default_factory=set)
|
requests_recieved: Set[int] = field(default_factory=set)
|
||||||
requests_working: Set[int] = field(default_factory=set)
|
requests_working: Set[int] = field(default_factory=set)
|
||||||
|
last_update: float = field(default_factory=time.time)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls):
|
def empty(cls):
|
||||||
@@ -221,12 +220,15 @@ class ModelMetrics:
|
|||||||
workload_served=0.0,
|
workload_served=0.0,
|
||||||
workload_cancelled=0.0,
|
workload_cancelled=0.0,
|
||||||
workload_errored=0.0,
|
workload_errored=0.0,
|
||||||
cur_perf=0.0,
|
|
||||||
workload_received=0.0,
|
workload_received=0.0,
|
||||||
error_msg=None,
|
error_msg=None,
|
||||||
max_throughput=0.0,
|
max_throughput=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_perf(self) -> float:
|
||||||
|
return max(self.workload_served / (time.time() - self.last_update), 0.0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workload_processing(self) -> float:
|
def workload_processing(self) -> float:
|
||||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||||
@@ -240,6 +242,7 @@ class ModelMetrics:
|
|||||||
self.workload_received = 0
|
self.workload_received = 0
|
||||||
self.workload_cancelled = 0
|
self.workload_cancelled = 0
|
||||||
self.workload_errored = 0
|
self.workload_errored = 0
|
||||||
|
self.last_update = time.time()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
+11
-14
@@ -5,7 +5,6 @@ import json
|
|||||||
from asyncio import sleep
|
from asyncio import sleep
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -47,33 +46,31 @@ class Metrics:
|
|||||||
self.model_metrics.requests_recieved.add(reqnum)
|
self.model_metrics.requests_recieved.add(reqnum)
|
||||||
self.model_metrics.requests_working.add(reqnum)
|
self.model_metrics.requests_working.add(reqnum)
|
||||||
|
|
||||||
def _request_end(
|
def _request_end(self, workload: float, reqnum: int) -> None:
|
||||||
self, workload: float, req_response_time: float, reqnum: int
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
this function is called after a response from model API is received.
|
this function is called after handling of a request ends, regardless of the outcome
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_served += workload
|
|
||||||
self.model_metrics.workload_pending -= workload
|
self.model_metrics.workload_pending -= workload
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
self.model_metrics.cur_perf = workload / req_response_time
|
|
||||||
|
def _request_success(self, workload: float) -> None:
|
||||||
|
"""
|
||||||
|
this function is called after a response from model API is received and forwarded.
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_served += workload
|
||||||
self.update_pending = True
|
self.update_pending = True
|
||||||
|
|
||||||
def _request_errored(self, workload: float, reqnum: int) -> None:
|
def _request_errored(self, workload: float) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if model API returns an error
|
this function is called if model API returns an error
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_pending -= workload
|
|
||||||
self.model_metrics.workload_errored += workload
|
self.model_metrics.workload_errored += workload
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
|
||||||
|
|
||||||
def _request_canceled(self, workload: float, reqnum: int) -> None:
|
def _request_canceled(self, workload: float) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if client drops connection before model API has responded
|
this function is called if client drops connection before model API has responded
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_pending -= workload
|
|
||||||
self.model_metrics.workload_cancelled += workload
|
self.model_metrics.workload_cancelled += workload
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
|
||||||
|
|
||||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
while True:
|
while True:
|
||||||
@@ -119,7 +116,7 @@ class Metrics:
|
|||||||
|
|
||||||
def send_data(report_addr: str) -> None:
|
def send_data(report_addr: str) -> None:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
full_path = urljoin(report_addr, "/worker_status/")
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
|
|||||||
+27
-9
@@ -10,6 +10,7 @@ from collections import Counter
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from lib.data_types import AuthData, ApiPayload
|
from lib.data_types import AuthData, ApiPayload
|
||||||
@@ -53,6 +54,13 @@ test_args.add_argument(
|
|||||||
default="https://run.vast.ai",
|
default="https://run.vast.ai",
|
||||||
help="Call local autoscaler instead of prod, for dev use only",
|
help="Call local autoscaler instead of prod, for dev use only",
|
||||||
)
|
)
|
||||||
|
test_args.add_argument(
|
||||||
|
"-i",
|
||||||
|
dest="instance",
|
||||||
|
type=str,
|
||||||
|
default="prod",
|
||||||
|
help="Autoscaler shard to run the command against, default: prod",
|
||||||
|
)
|
||||||
|
|
||||||
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
||||||
|
|
||||||
@@ -70,6 +78,7 @@ class ClientState:
|
|||||||
api_key: str
|
api_key: str
|
||||||
server_url: str
|
server_url: str
|
||||||
worker_endpoint: str
|
worker_endpoint: str
|
||||||
|
instance: str
|
||||||
payload: ApiPayload
|
payload: ApiPayload
|
||||||
url: str = ""
|
url: str = ""
|
||||||
status: ClientStatus = ClientStatus.FetchEndpoint
|
status: ClientStatus = ClientStatus.FetchEndpoint
|
||||||
@@ -79,11 +88,7 @@ class ClientState:
|
|||||||
|
|
||||||
def make_call(self):
|
def make_call(self):
|
||||||
self.status = ClientStatus.FetchEndpoint
|
self.status = ClientStatus.FetchEndpoint
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
if not self.api_key:
|
||||||
endpoint_name=self.endpoint_group_name,
|
|
||||||
account_api_key=self.api_key,
|
|
||||||
)
|
|
||||||
if not endpoint_api_key:
|
|
||||||
self.as_error.append(
|
self.as_error.append(
|
||||||
f"Endpoint {self.endpoint_group_name} not found for API key",
|
f"Endpoint {self.endpoint_group_name} not found for API key",
|
||||||
)
|
)
|
||||||
@@ -91,12 +96,14 @@ class ClientState:
|
|||||||
return
|
return
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint": self.endpoint_group_name,
|
"endpoint": self.endpoint_group_name,
|
||||||
"api_key": endpoint_api_key,
|
"api_key": self.api_key,
|
||||||
"cost": self.payload.count_workload(),
|
"cost": self.payload.count_workload(),
|
||||||
}
|
}
|
||||||
|
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
urljoin(self.server_url, "/route/"),
|
urljoin(self.server_url, "/route/"),
|
||||||
json=route_payload,
|
json=route_payload,
|
||||||
|
headers=headers,
|
||||||
timeout=4,
|
timeout=4,
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@@ -114,9 +121,11 @@ class ClientState:
|
|||||||
self.url = worker_address
|
self.url = worker_address
|
||||||
url = urljoin(worker_address, self.worker_endpoint)
|
url = urljoin(worker_address, self.worker_endpoint)
|
||||||
self.status = ClientStatus.Generating
|
self.status = ClientStatus.Generating
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
self.infer_error.append(
|
self.infer_error.append(
|
||||||
@@ -135,6 +144,7 @@ class ClientState:
|
|||||||
try:
|
try:
|
||||||
self.make_call()
|
self.make_call()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
self.status = ClientStatus.Error
|
self.status = ClientStatus.Error
|
||||||
_ = e
|
_ = e
|
||||||
self.conn_errors[self.url] += 1
|
self.conn_errors[self.url] += 1
|
||||||
@@ -226,6 +236,7 @@ def run_test(
|
|||||||
server_url: str,
|
server_url: str,
|
||||||
worker_endpoint: str,
|
worker_endpoint: str,
|
||||||
payload_cls: Type[ApiPayload],
|
payload_cls: Type[ApiPayload],
|
||||||
|
instance: str,
|
||||||
):
|
):
|
||||||
threads = []
|
threads = []
|
||||||
|
|
||||||
@@ -234,8 +245,7 @@ def run_test(
|
|||||||
print_thread.daemon = True # makes threads get killed on program exit
|
print_thread.daemon = True # makes threads get killed on program exit
|
||||||
print_thread.start()
|
print_thread.start()
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=endpoint_group_name,
|
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
||||||
account_api_key=api_key,
|
|
||||||
)
|
)
|
||||||
if not endpoint_api_key:
|
if not endpoint_api_key:
|
||||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||||
@@ -248,6 +258,7 @@ def run_test(
|
|||||||
server_url=server_url,
|
server_url=server_url,
|
||||||
worker_endpoint=worker_endpoint,
|
worker_endpoint=worker_endpoint,
|
||||||
payload=payload_cls.for_test(),
|
payload=payload_cls.for_test(),
|
||||||
|
instance=instance,
|
||||||
)
|
)
|
||||||
clients.append(client)
|
clients.append(client)
|
||||||
thread = threading.Thread(target=client.simulate_user, args=())
|
thread = threading.Thread(target=client.simulate_user, args=())
|
||||||
@@ -281,12 +292,19 @@ def test_load_cmd(
|
|||||||
args = arg_parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
if hasattr(args, "comfy_model"):
|
if hasattr(args, "comfy_model"):
|
||||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||||
|
server_url = dict(
|
||||||
|
prod="https://run.vast.ai",
|
||||||
|
alpha="https://run-alpha.vast.ai",
|
||||||
|
candidate="https://run-candidate.vast.ai",
|
||||||
|
local="http://localhost:8080",
|
||||||
|
)[args.instance]
|
||||||
run_test(
|
run_test(
|
||||||
num_requests=args.num_requests,
|
num_requests=args.num_requests,
|
||||||
requests_per_second=args.requests_per_second,
|
requests_per_second=args.requests_per_second,
|
||||||
api_key=args.api_key,
|
api_key=args.api_key,
|
||||||
server_url=args.server_url,
|
server_url=server_url,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
worker_endpoint=endpoint,
|
worker_endpoint=endpoint,
|
||||||
payload_cls=payload_cls,
|
payload_cls=payload_cls,
|
||||||
|
instance=args.instance,
|
||||||
)
|
)
|
||||||
|
|||||||
+3
-2
@@ -1,4 +1,4 @@
|
|||||||
aiohttp~=3.11
|
aiohttp[speedups]==3.10.1
|
||||||
anyio~=4.4
|
anyio~=4.4
|
||||||
lib~=4.0
|
lib~=4.0
|
||||||
nltk~=3.9
|
nltk~=3.9
|
||||||
@@ -6,4 +6,5 @@ psutil~=6.0
|
|||||||
pycryptodome~=3.20
|
pycryptodome~=3.20
|
||||||
Requests~=2.32
|
Requests~=2.32
|
||||||
transformers~=4.52
|
transformers~=4.52
|
||||||
utils~=1.0
|
utils==1.0.*
|
||||||
|
hf_transfer>=0.1.9
|
||||||
|
|||||||
+16
-14
@@ -46,17 +46,19 @@ env | grep _ >> /etc/environment;
|
|||||||
|
|
||||||
if [ ! -d "$ENV_PATH" ]
|
if [ ! -d "$ENV_PATH" ]
|
||||||
then
|
then
|
||||||
apt install -y python3.10-venv
|
|
||||||
echo "setting up venv"
|
echo "setting up venv"
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
source ~/.local/bin/env
|
||||||
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
||||||
|
|
||||||
python3 -m venv "$WORKSPACE_DIR/worker-env"
|
uv venv --managed-python "$WORKSPACE_DIR/worker-env" -p 3.10
|
||||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||||
|
|
||||||
pip install -r vast-pyworker/requirements.txt
|
uv pip install -r vast-pyworker/requirements.txt
|
||||||
|
|
||||||
touch ~/.no_auto_tmux
|
touch ~/.no_auto_tmux
|
||||||
else
|
else
|
||||||
|
source ~/.local/bin/env
|
||||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||||
echo "environment activated"
|
echo "environment activated"
|
||||||
echo "venv: $VIRTUAL_ENV"
|
echo "venv: $VIRTUAL_ENV"
|
||||||
@@ -87,23 +89,23 @@ if [ "$USE_SSL" = true ]; then
|
|||||||
IP.1 = 0.0.0.0
|
IP.1 = 0.0.0.0
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
||||||
-nodes \
|
-nodes \
|
||||||
-sha256 \
|
-sha256 \
|
||||||
-keyout /etc/instance.key \
|
-keyout /etc/instance.key \
|
||||||
-out /etc/instance.csr \
|
-out /etc/instance.csr \
|
||||||
-config /etc/openssl-san.cnf
|
-config /etc/openssl-san.cnf
|
||||||
|
|
||||||
curl --header 'Content-Type: application/octet-stream' \
|
curl --header 'Content-Type: application/octet-stream' \
|
||||||
--data-binary @//etc/instance.csr \
|
--data-binary @//etc/instance.csr \
|
||||||
-X \
|
-X \
|
||||||
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export REPORT_ADDR WORKER_PORT USE_SSL
|
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
||||||
|
|
||||||
cd "$SERVER_DIR"
|
cd "$SERVER_DIR"
|
||||||
|
|
||||||
|
|||||||
+12
-3
@@ -17,7 +17,9 @@ class Endpoint:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]:
|
def get_endpoint_api_key(
|
||||||
|
endpoint_name: str, account_api_key: str, instance: str
|
||||||
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Fetch endpoint API key from VastAI console following the healthcheck pattern.
|
Fetch endpoint API key from VastAI console following the healthcheck pattern.
|
||||||
|
|
||||||
@@ -28,12 +30,19 @@ class Endpoint:
|
|||||||
Returns:
|
Returns:
|
||||||
Endpoint API key if successful, None otherwise
|
Endpoint API key if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
|
endpoints = {
|
||||||
|
"alpha": "alpha",
|
||||||
|
"candidate": "candidate",
|
||||||
|
"prod": "console",
|
||||||
|
}
|
||||||
|
vast_console_url = f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
||||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||||
response = requests.get(vast_console_url, headers=headers)
|
response = requests.get(
|
||||||
|
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
|
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
import tempfile
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_cert_file_path():
|
||||||
|
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||||
|
response = requests.get(cert_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
# Use a temporary file that is not deleted on close
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
return f.name
|
||||||
@@ -5,6 +5,7 @@ import requests
|
|||||||
|
|
||||||
from lib.test_utils import print_truncate_res
|
from lib.test_utils import print_truncate_res
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
"""
|
"""
|
||||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||||
@@ -51,6 +52,7 @@ def call_default_workflow(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
@@ -153,6 +156,7 @@ if __name__ == "__main__":
|
|||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=args.endpoint_group_name,
|
endpoint_name=args.endpoint_group_name,
|
||||||
account_api_key=args.api_key,
|
account_api_key=args.api_key,
|
||||||
|
instance=args.instance,
|
||||||
)
|
)
|
||||||
if endpoint_api_key:
|
if endpoint_api_key:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
# OpenAI Compatible PyWorker
|
||||||
|
|
||||||
|
This is the base PyWorker for OpenAI compatible inference servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||||
|
|
||||||
|
## Instance Setup
|
||||||
|
|
||||||
|
1. Pick a template
|
||||||
|
|
||||||
|
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||||
|
|
||||||
|
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||||
|
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||||
|
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||||
|
|
||||||
|
|
||||||
|
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||||
|
|
||||||
|
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
|
## Client Setup (Demo)
|
||||||
|
|
||||||
|
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/vast-ai/pyworker
|
||||||
|
cd pyworker
|
||||||
|
pip install uv
|
||||||
|
uv venv -p 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using the Test Client
|
||||||
|
|
||||||
|
Several examples have been provided in the client to help you get started with your own implementation.
|
||||||
|
|
||||||
|
### Completions
|
||||||
|
|
||||||
|
Call to `/v1/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat Completion (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat Completion (streaming)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with streaming response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Use (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with tool and json response.
|
||||||
|
|
||||||
|
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Interactive Chat (streaming)
|
||||||
|
|
||||||
|
Interactive session with calls to `/v1/chat/completions`.
|
||||||
|
|
||||||
|
Type `clear` to clear the chat history or `quit` to exit.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
|
||||||
|
|
||||||
|
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
|
||||||
|
|
||||||
|
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
|
||||||
|
|
||||||
|
| Variable | Default Value | Used For |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
|
||||||
|
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
|
||||||
|
|
||||||
|
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
We have provided a demonstration client to help you implement this template into your own infrastructure
|
||||||
|
|
||||||
|
### Client Setup
|
||||||
|
|
||||||
|
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/vast-ai/pyworker
|
||||||
|
cd pyworker
|
||||||
|
pip install uv
|
||||||
|
uv venv -p 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### Completions
|
||||||
|
|
||||||
|
Call to `/v1/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat Completion (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat Completion (streaming)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with streaming response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Use (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with tool and json response.
|
||||||
|
|
||||||
|
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Interactive Chat (streaming)
|
||||||
|
|
||||||
|
Interactive session with calls to `/v1/chat/completions`.
|
||||||
|
|
||||||
|
Type `clear` to clear the chat history or `quit` to exit.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||||
|
```
|
||||||
@@ -0,0 +1,599 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||||
|
import requests
|
||||||
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||||
|
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
|
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
||||||
|
|
||||||
|
|
||||||
|
class APIClient:
|
||||||
|
"""Lightweight client focused solely on API communication"""
|
||||||
|
|
||||||
|
# Remove the generic WORKER_ENDPOINT since we're now going direct
|
||||||
|
DEFAULT_COST = 100
|
||||||
|
DEFAULT_TIMEOUT = 4
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint_group_name: str,
|
||||||
|
api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
endpoint_api_key: str,
|
||||||
|
):
|
||||||
|
self.endpoint_group_name = endpoint_group_name
|
||||||
|
self.api_key = api_key
|
||||||
|
self.server_url = server_url
|
||||||
|
self.endpoint_api_key = endpoint_api_key
|
||||||
|
|
||||||
|
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||||
|
"""Get worker URL and auth data from routing service"""
|
||||||
|
if not self.endpoint_api_key:
|
||||||
|
raise ValueError("No valid endpoint API key available")
|
||||||
|
|
||||||
|
route_payload = {
|
||||||
|
"endpoint": self.endpoint_group_name,
|
||||||
|
"api_key": self.endpoint_api_key,
|
||||||
|
"cost": cost,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
urljoin(self.server_url, "/route/"),
|
||||||
|
json=route_payload,
|
||||||
|
timeout=self.DEFAULT_TIMEOUT,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Create auth data from routing response"""
|
||||||
|
return {
|
||||||
|
"signature": message["signature"],
|
||||||
|
"cost": message["cost"],
|
||||||
|
"endpoint": message["endpoint"],
|
||||||
|
"reqnum": message["reqnum"],
|
||||||
|
"url": message["url"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _make_request(
|
||||||
|
self,
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
endpoint: str,
|
||||||
|
method: str = "POST",
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||||
|
"""Make request directly to the specific worker endpoint"""
|
||||||
|
# Get worker URL and auth data
|
||||||
|
cost = payload.get("max_tokens", self.DEFAULT_COST)
|
||||||
|
message = self._get_worker_url(cost=cost)
|
||||||
|
worker_url = message["url"]
|
||||||
|
auth_data = self._create_auth_data(message)
|
||||||
|
|
||||||
|
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
|
||||||
|
|
||||||
|
url = urljoin(worker_url, endpoint)
|
||||||
|
log.debug(f"Making direct request to: {url}")
|
||||||
|
log.debug(f"Payload: {req_data}")
|
||||||
|
|
||||||
|
# Make the request using the specified method
|
||||||
|
if method.upper() == "POST":
|
||||||
|
response = requests.post(
|
||||||
|
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
|
elif method.upper() == "GET":
|
||||||
|
response = requests.get(
|
||||||
|
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._handle_streaming_response(response)
|
||||||
|
else:
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
|
||||||
|
"""Handle streaming response and yield tokens"""
|
||||||
|
try:
|
||||||
|
for line in response.iter_lines(decode_unicode=True):
|
||||||
|
if line:
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
yield data # Yield the full chunk
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error handling streaming response: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def call_completions(
|
||||||
|
self, config: CompletionConfig
|
||||||
|
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||||
|
payload = config.to_dict()
|
||||||
|
|
||||||
|
return self._make_request(
|
||||||
|
payload=payload, endpoint="/v1/completions", stream=config.stream
|
||||||
|
)
|
||||||
|
|
||||||
|
def call_chat_completions(
|
||||||
|
self, config: ChatCompletionConfig
|
||||||
|
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||||
|
payload = config.to_dict()
|
||||||
|
|
||||||
|
return self._make_request(
|
||||||
|
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolManager:
|
||||||
|
"""Handles tool definitions and execution"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_files() -> str:
|
||||||
|
"""Execute ls on current directory"""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["ls", "-la", "."], capture_output=True, text=True, timeout=10
|
||||||
|
)
|
||||||
|
if result.returncode == 0:
|
||||||
|
return result.stdout
|
||||||
|
else:
|
||||||
|
return f"Error: {result.stderr}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error running ls: {e}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||||
|
"""Get the ls tool definition"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "list_files",
|
||||||
|
"description": "List files and directories in the cwd",
|
||||||
|
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||||
|
"""Execute a tool call and return the result"""
|
||||||
|
function_name = tool_call["function"]["name"]
|
||||||
|
|
||||||
|
if function_name == "list_files":
|
||||||
|
return self.list_files()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown tool function: {function_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class APIDemo:
|
||||||
|
"""Demo and testing functionality for the API client"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
|
||||||
|
):
|
||||||
|
self.client = client
|
||||||
|
self.model = model
|
||||||
|
self.tool_manager = tool_manager or ToolManager()
|
||||||
|
|
||||||
|
def handle_streaming_response(
|
||||||
|
self, response_stream, show_reasoning: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Handle streaming chat response and display all output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
full_response = ""
|
||||||
|
reasoning_content = ""
|
||||||
|
reasoning_started = False
|
||||||
|
content_started = False
|
||||||
|
|
||||||
|
for chunk in response_stream:
|
||||||
|
# Normalize the chunk
|
||||||
|
if isinstance(chunk, str):
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if chunk.startswith("data: "):
|
||||||
|
chunk = chunk[6:].strip()
|
||||||
|
if chunk in ["[DONE]", ""]:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
parsed_chunk = json.loads(chunk)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
elif isinstance(chunk, dict):
|
||||||
|
parsed_chunk = chunk
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse delta from the chunk
|
||||||
|
choices = parsed_chunk.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
delta = choices[0].get("delta", {})
|
||||||
|
reasoning_token = delta.get("reasoning_content", "")
|
||||||
|
content_token = delta.get("content", "")
|
||||||
|
|
||||||
|
# Print reasoning token if applicable
|
||||||
|
if show_reasoning and reasoning_token:
|
||||||
|
if not reasoning_started:
|
||||||
|
print("\n🧠 Reasoning: ", end="", flush=True)
|
||||||
|
reasoning_started = True
|
||||||
|
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
|
||||||
|
reasoning_content += reasoning_token
|
||||||
|
|
||||||
|
# Print content token
|
||||||
|
if content_token:
|
||||||
|
if not content_started:
|
||||||
|
if show_reasoning and reasoning_started:
|
||||||
|
print(f"\n💬 Response: ", end="", flush=True)
|
||||||
|
else:
|
||||||
|
print("Assistant: ", end="", flush=True)
|
||||||
|
content_started = True
|
||||||
|
print(content_token, end="", flush=True)
|
||||||
|
full_response += content_token
|
||||||
|
|
||||||
|
print() # Ensure newline after response
|
||||||
|
|
||||||
|
if show_reasoning:
|
||||||
|
if reasoning_started or content_started:
|
||||||
|
print("\nStreaming completed.")
|
||||||
|
if reasoning_started:
|
||||||
|
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||||
|
if content_started:
|
||||||
|
print(f"Response tokens: {len(full_response.split())}")
|
||||||
|
|
||||||
|
return full_response
|
||||||
|
|
||||||
|
def test_tool_support(self) -> bool:
|
||||||
|
"""Test if the endpoint supports function calling"""
|
||||||
|
log.debug("Testing endpoint tool calling support...")
|
||||||
|
|
||||||
|
# Try a simple request with minimal tools to test support
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
minimal_tool = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "test_function", "description": "Test function"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
config = ChatCompletionConfig(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
tools=minimal_tool,
|
||||||
|
tool_choice="none", # Don't actually call the tool
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.call_chat_completions(config)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error: Endpoint does not support tool calling: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def demo_completions(self) -> None:
|
||||||
|
"""Demo: test basic completions endpoint"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("COMPLETIONS DEMO")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
config = CompletionConfig(
|
||||||
|
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
|
||||||
|
)
|
||||||
|
response = self.client.call_completions(config)
|
||||||
|
|
||||||
|
if isinstance(response, dict):
|
||||||
|
print("\nResponse:")
|
||||||
|
print(json.dumps(response, indent=2))
|
||||||
|
else:
|
||||||
|
log.error("Unexpected response format")
|
||||||
|
|
||||||
|
def demo_chat(self, use_streaming: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Demo: test chat completions endpoint with optional streaming
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print(
|
||||||
|
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
|
||||||
|
)
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
config = ChatCompletionConfig(
|
||||||
|
model=self.model,
|
||||||
|
messages=[{"role": "user", "content": CHAT_PROMPT}],
|
||||||
|
stream=use_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(f"Testing chat completions with model '{self.model}'...")
|
||||||
|
response = self.client.call_chat_completions(config)
|
||||||
|
|
||||||
|
if use_streaming:
|
||||||
|
try:
|
||||||
|
self.handle_streaming_response(response, show_reasoning=True)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"\nError during streaming: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
if isinstance(response, dict):
|
||||||
|
choice = response.get("choices", [{}])[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
content = message.get("content", "")
|
||||||
|
reasoning = message.get("reasoning_content", "") or message.get(
|
||||||
|
"reasoning", ""
|
||||||
|
)
|
||||||
|
|
||||||
|
if reasoning:
|
||||||
|
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||||
|
|
||||||
|
print(f"\n💬 Assistant: {content}")
|
||||||
|
print(f"\nFull Response:")
|
||||||
|
print(json.dumps(response, indent=2))
|
||||||
|
else:
|
||||||
|
log.error("Unexpected response format")
|
||||||
|
|
||||||
|
def demo_ls_tool(self) -> None:
|
||||||
|
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("TOOL USE DEMO: List Directory Contents")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Test if tools are supported first
|
||||||
|
if not self.test_tool_support():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Request with tool available
|
||||||
|
messages = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||||
|
|
||||||
|
config = ChatCompletionConfig(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=self.tool_manager.get_ls_tool_definition(),
|
||||||
|
tool_choice="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(f"Making initial request with tool using model '{self.model}'...")
|
||||||
|
response = self.client.call_chat_completions(config)
|
||||||
|
|
||||||
|
if not isinstance(response, dict):
|
||||||
|
raise ValueError("Expected dict response for tool use")
|
||||||
|
|
||||||
|
choice = response.get("choices", [{}])[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
|
||||||
|
print(f"Assistant response: {message.get('content', 'No content')}")
|
||||||
|
|
||||||
|
# Check for tool calls
|
||||||
|
tool_calls = message.get("tool_calls")
|
||||||
|
if not tool_calls:
|
||||||
|
raise ValueError(
|
||||||
|
"No tool calls made - model may not support function calling"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Tool calls detected: {len(tool_calls)}")
|
||||||
|
|
||||||
|
# Execute the tool call
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
function_name = tool_call["function"]["name"]
|
||||||
|
print(f"Executing tool: {function_name}")
|
||||||
|
|
||||||
|
tool_result = self.tool_manager.execute_tool_call(tool_call)
|
||||||
|
print(f"Tool result:\n{tool_result}")
|
||||||
|
|
||||||
|
# Add tool result and continue conversation
|
||||||
|
messages.append(message) # Add assistant's message with tool call
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call["id"],
|
||||||
|
"content": tool_result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get final response
|
||||||
|
final_config = ChatCompletionConfig(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=self.tool_manager.get_ls_tool_definition(),
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Getting final response...")
|
||||||
|
final_response = self.client.call_chat_completions(final_config)
|
||||||
|
|
||||||
|
if isinstance(final_response, dict):
|
||||||
|
final_choice = final_response.get("choices", [{}])[0]
|
||||||
|
final_message = final_choice.get("message", {})
|
||||||
|
final_content = final_message.get("content", "")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("FINAL LLM ANALYSIS:")
|
||||||
|
print("=" * 60)
|
||||||
|
print(final_content)
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
def interactive_chat(self) -> None:
|
||||||
|
"""Interactive chat session with streaming"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("INTERACTIVE STREAMING CHAT")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Using model: {self.model}")
|
||||||
|
print("Type 'quit' to exit, 'clear' to clear history")
|
||||||
|
print()
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
user_input = input("You: ").strip()
|
||||||
|
|
||||||
|
if user_input.lower() == "quit":
|
||||||
|
print("👋 Goodbye!")
|
||||||
|
break
|
||||||
|
elif user_input.lower() == "clear":
|
||||||
|
messages = []
|
||||||
|
print("Chat history cleared")
|
||||||
|
continue
|
||||||
|
elif not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
|
config = ChatCompletionConfig(
|
||||||
|
model=self.model, messages=messages, stream=True, temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Assistant: ", end="", flush=True)
|
||||||
|
|
||||||
|
response = self.client.call_chat_completions(config)
|
||||||
|
assistant_content = self.handle_streaming_response(
|
||||||
|
response, show_reasoning=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add assistant response to conversation history
|
||||||
|
messages.append({"role": "assistant", "content": assistant_content})
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n👋 Chat interrupted. Goodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"\nError: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function with CLI switches for different tests"""
|
||||||
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
|
# Add mandatory model argument
|
||||||
|
test_args.add_argument(
|
||||||
|
"--model", required=True, help="Model to use for requests (required)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add test mode arguments
|
||||||
|
test_args.add_argument(
|
||||||
|
"--completion", action="store_true", help="Test completions endpoint"
|
||||||
|
)
|
||||||
|
test_args.add_argument(
|
||||||
|
"--chat",
|
||||||
|
action="store_true",
|
||||||
|
help="Test chat completions endpoint (non-streaming)",
|
||||||
|
)
|
||||||
|
test_args.add_argument(
|
||||||
|
"--chat-stream",
|
||||||
|
action="store_true",
|
||||||
|
help="Test chat completions endpoint with streaming",
|
||||||
|
)
|
||||||
|
test_args.add_argument(
|
||||||
|
"--tools",
|
||||||
|
action="store_true",
|
||||||
|
help="Test function calling with ls tool (non-streaming)",
|
||||||
|
)
|
||||||
|
test_args.add_argument(
|
||||||
|
"--interactive",
|
||||||
|
action="store_true",
|
||||||
|
help="Start interactive streaming chat session",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = test_args.parse_args()
|
||||||
|
|
||||||
|
# Check that only one test mode is selected
|
||||||
|
test_modes = [
|
||||||
|
args.completion,
|
||||||
|
args.chat,
|
||||||
|
args.chat_stream,
|
||||||
|
args.tools,
|
||||||
|
args.interactive,
|
||||||
|
]
|
||||||
|
selected_count = sum(test_modes)
|
||||||
|
|
||||||
|
if selected_count == 0:
|
||||||
|
print("Please specify exactly one test mode:")
|
||||||
|
print(" --completion : Test completions endpoint")
|
||||||
|
print(" --chat : Test chat completions endpoint (non-streaming)")
|
||||||
|
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||||
|
print(" --tools : Test function calling with ls tool (non-streaming)")
|
||||||
|
print(" --interactive : Start interactive streaming chat session")
|
||||||
|
print(
|
||||||
|
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
elif selected_count > 1:
|
||||||
|
print("Please specify exactly one test mode")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
instance=args.instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not endpoint_api_key:
|
||||||
|
log.error(
|
||||||
|
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Create the core API client
|
||||||
|
client = APIClient(
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
api_key=args.api_key,
|
||||||
|
server_url=args.server_url,
|
||||||
|
endpoint_api_key=endpoint_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create tool manager and demo (passing the model parameter)
|
||||||
|
tool_manager = ToolManager()
|
||||||
|
demo = APIDemo(client, args.model, tool_manager)
|
||||||
|
|
||||||
|
print(f"Using model: {args.model}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Run the selected test
|
||||||
|
if args.completion:
|
||||||
|
demo.demo_completions()
|
||||||
|
elif args.chat:
|
||||||
|
demo.demo_chat(use_streaming=False)
|
||||||
|
elif args.chat_stream:
|
||||||
|
demo.demo_chat(use_streaming=True)
|
||||||
|
elif args.tools:
|
||||||
|
demo.demo_ls_tool()
|
||||||
|
elif args.interactive:
|
||||||
|
demo.interactive_chat()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error during test: {e}", exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field, fields, is_dataclass
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class SerializableDataclass:
|
||||||
|
def _serialize_recursive(self, obj: Any) -> Any:
|
||||||
|
if is_dataclass(obj):
|
||||||
|
return {
|
||||||
|
field.name: self._serialize_recursive(getattr(obj, field.name))
|
||||||
|
for field in fields(obj)
|
||||||
|
}
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
return [self._serialize_recursive(item) for item in obj]
|
||||||
|
elif isinstance(obj, set):
|
||||||
|
return [self._serialize_recursive(item) for item in obj]
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return self._serialize_recursive(self)
|
||||||
|
|
||||||
|
def to_json(self, indent: int = 2) -> str:
|
||||||
|
return json.dumps(self.to_dict(), indent=indent)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompletionConfig(SerializableDataclass):
|
||||||
|
"""Configuration for completion requests"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
prompt: str = "Hello"
|
||||||
|
max_tokens: int = 256
|
||||||
|
temperature: float = 0.7
|
||||||
|
top_k: int = 20
|
||||||
|
top_p: float = 0.4
|
||||||
|
stream: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatCompletionConfig(SerializableDataclass):
|
||||||
|
"""Configuration for chat completion requests"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
messages: list = field(default_factory=list)
|
||||||
|
max_tokens: int = 2096
|
||||||
|
temperature: float = 0.7
|
||||||
|
top_k: int = 20
|
||||||
|
top_p: float = 0.4
|
||||||
|
stream: bool = False
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
|
||||||
|
tool_choice: str = "auto"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.messages is None:
|
||||||
|
self.messages = [{"role": "user", "content": "Hello"}]
|
||||||
@@ -0,0 +1,182 @@
|
|||||||
|
import os, json, random
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
||||||
|
from typing import Union, Type, Dict, Any, Optional
|
||||||
|
from aiohttp import web, ClientResponse
|
||||||
|
import nltk
|
||||||
|
import logging
|
||||||
|
|
||||||
|
nltk.download("words")
|
||||||
|
WORD_LIST = nltk.corpus.words.words()
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Generic dataclass accepts any dictionary in input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenericData(ApiPayload, ABC):
|
||||||
|
input: Dict[str, Any]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
||||||
|
return cls(input=data["input"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
# Validate required parameters
|
||||||
|
required_params = ["input"]
|
||||||
|
for param in required_params:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create clean data dict and delegate to from_dict
|
||||||
|
clean_data = {"input": json_msg["input"]}
|
||||||
|
|
||||||
|
return cls.from_dict(clean_data)
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, JsonDataException) as e:
|
||||||
|
errors["parameters"] = str(e)
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def for_test(cls) -> "GenericData":
|
||||||
|
pass
|
||||||
|
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
return self.input
|
||||||
|
|
||||||
|
def count_workload(self) -> int:
|
||||||
|
return self.input.get("max_tokens", 0)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
return os.environ.get("MODEL_HEALTH_ENDPOINT")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[GenericData]:
|
||||||
|
return GenericData
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def make_benchmark_payload(self) -> GenericData:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
match model_response.status:
|
||||||
|
case 200:
|
||||||
|
# Check if the response is actually streaming based on response headers/content-type
|
||||||
|
is_streaming_response = (
|
||||||
|
model_response.content_type == "text/event-stream"
|
||||||
|
or model_response.content_type == "application/x-ndjson"
|
||||||
|
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||||
|
or "stream" in model_response.content_type.lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_streaming_response:
|
||||||
|
log.debug("Detected streaming response...")
|
||||||
|
res = web.StreamResponse()
|
||||||
|
res.content_type = model_response.content_type
|
||||||
|
await res.prepare(client_request)
|
||||||
|
async for chunk in model_response.content:
|
||||||
|
await res.write(chunk)
|
||||||
|
await res.write_eof()
|
||||||
|
log.debug("Done streaming response")
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
log.debug("Detected non-streaming response...")
|
||||||
|
content = await model_response.read()
|
||||||
|
return web.Response(
|
||||||
|
body=content,
|
||||||
|
status=200,
|
||||||
|
content_type=model_response.content_type,
|
||||||
|
)
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompletionsData(GenericData):
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls) -> "CompletionsData":
|
||||||
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
model = os.environ.get("MODEL_NAME")
|
||||||
|
if not model:
|
||||||
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
|
test_input = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
|
}
|
||||||
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompletionsHandler(GenericHandler):
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/v1/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[CompletionsData]:
|
||||||
|
return CompletionsData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> CompletionsData:
|
||||||
|
return CompletionsData.for_test()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatCompletionsData(GenericData):
|
||||||
|
"""Chat completions-specific data implementation"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls) -> "ChatCompletionsData":
|
||||||
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
model = os.environ.get("MODEL_NAME")
|
||||||
|
if not model:
|
||||||
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
|
# Chat completions use messages format instead of prompt
|
||||||
|
test_input = {
|
||||||
|
"model": model,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
|
}
|
||||||
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatCompletionsHandler(GenericHandler):
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/v1/chat/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[ChatCompletionsData]:
|
||||||
|
return ChatCompletionsData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> ChatCompletionsData:
|
||||||
|
return ChatCompletionsData.for_test()
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
|
||||||
|
from aiohttp import web
|
||||||
|
from lib.backend import Backend, LogAction
|
||||||
|
from lib.server import start_server
|
||||||
|
|
||||||
|
# This line indicates that the inference server is listening
|
||||||
|
MODEL_SERVER_START_LOG_MSG = [
|
||||||
|
"Application startup complete.", # vLLM
|
||||||
|
"llama runner started", # Ollama
|
||||||
|
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||||
|
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||||
|
]
|
||||||
|
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
|
"INFO exited: vllm", # vLLM
|
||||||
|
"RuntimeError: Engine", # vLLM
|
||||||
|
"Error: pull model manifest:", # Ollama
|
||||||
|
"stalled; retrying", # Ollama
|
||||||
|
"Error: WebserverFailed", # TGI
|
||||||
|
"Error: DownloadError", # TGI
|
||||||
|
"Error: ShardCannotStart", # TGI
|
||||||
|
]
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
backend = Backend(
|
||||||
|
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||||
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
|
allow_parallel_requests=True,
|
||||||
|
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
|
log_actions=[
|
||||||
|
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||||
|
(LogAction.Info, '"message":"Download'),
|
||||||
|
*[
|
||||||
|
(LogAction.ModelError, error_msg)
|
||||||
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ping(_):
|
||||||
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
||||||
|
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
||||||
|
web.get("/ping", handle_ping),
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_server(backend, routes)
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
from lib.test_utils import test_load_cmd, test_args
|
||||||
|
from .data_types.server import CompletionsData
|
||||||
|
import os
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/v1/completions"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Check if MODEL_NAME environment variable is set
|
||||||
|
model_name_set = os.environ.get("MODEL_NAME") is not None
|
||||||
|
|
||||||
|
# Add model argument - required only if MODEL_NAME is not set
|
||||||
|
test_args.add_argument(
|
||||||
|
"--model",
|
||||||
|
dest="model",
|
||||||
|
required=not model_name_set,
|
||||||
|
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse known args to get model early, before test_load_cmd adds its args
|
||||||
|
known_args, _ = test_args.parse_known_args()
|
||||||
|
|
||||||
|
# Set environment variable if model was provided
|
||||||
|
if hasattr(known_args, "model") and known_args.model:
|
||||||
|
os.environ["MODEL_NAME"] = known_args.model
|
||||||
|
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||||
|
|
||||||
|
# Now call test_load_cmd normally - it will add its own args and re-parse
|
||||||
|
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
|||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
print(f"url: {url}")
|
print(f"url: {url}")
|
||||||
response = requests.post(url, json=req_data)
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
res = response.json()
|
res = response.json()
|
||||||
print(res)
|
print(res)
|
||||||
@@ -100,6 +105,7 @@ if __name__ == "__main__":
|
|||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=args.endpoint_group_name,
|
endpoint_name=args.endpoint_group_name,
|
||||||
account_api_key=args.api_key,
|
account_api_key=args.api_key,
|
||||||
|
instance=args.instance,
|
||||||
)
|
)
|
||||||
if endpoint_api_key:
|
if endpoint_api_key:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user