Compare commits

...

9 Commits

Author SHA1 Message Date
Nader Arbabian d3be9fe7db redo metrics tracking for requests, fixes bug wherere some requests were marked as pending, even though they had finished 2025-07-30 18:56:51 -07:00
Rob Ballantyne e0be45f39a Addresses breaking change in core pyworker (#22)
* Addresses breaking change in test_utils.py

Endpoint.get_endpoint_api_key() now requires instance

Moves the call to this function out of the APIClient and into main

* Ensure make_benchmark_payload has a value to calculate the workload

---------

Co-authored-by: Nader Arbabian <nader@vast.ai>
2025-07-18 16:11:10 -07:00
Nader Arbabian be2aafdb1f fix pyright errors + revert to old way of handling cancelled api requests (#23) 2025-07-17 16:59:06 -07:00
Rob Ballantyne 9e369c55a5 Ensure venv creation where python is unavailable (#21) 2025-07-17 09:59:35 -07:00
Rob Ballantyne 69d9b7455f OpenAI compatible worker (#19)
Adds initial support for OpenAI compatible inference servers

Available endpoints:

- `/v1/completions`
- `/v1/chat/completions`
2025-07-16 09:46:26 +01:00
Nader Arbabian 6fb610cb5b fix pyworker miscounting active connections (#20)
* fix pyworker miscounting active connections

* clean up some issues

* add option to skip auth
2025-07-15 15:33:27 -07:00
Nader Arbabian 0bf2d04223 stop using urljoin for worker_status endpoint 2025-06-17 23:09:45 -07:00
Nader Arbabian 9ebf1924ea don't healthcheck endpoints until model is loaded and benchmarks have run 2025-06-11 15:26:50 -07:00
Nader Arbabian 0ab9a13a46 update tokenizers deps 2025-06-10 17:56:06 -07:00
18 changed files with 1185 additions and 77 deletions
+37 -32
View File
@@ -8,6 +8,7 @@ import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
@@ -55,11 +56,15 @@ class Backend:
reqnum = -1 reqnum = -1
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
@@ -118,14 +123,10 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
async def wait_for_disconnection() -> None:
while request.transport and not request.transport.is_closing():
await sleep(0.5)
async def cancel_api_call_if_disconnected() -> web.Response: async def cancel_api_call_if_disconnected() -> web.Response:
await wait_for_disconnection() await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum) self.metrics._request_canceled(workload=workload)
return web.Response(status=500) return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
@@ -140,7 +141,6 @@ class Backend:
else: else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}") log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try: try:
start_time = time.time()
response = await self.__call_api(handler=handler, payload=payload) response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status status_code = response.status
log.debug( log.debug(
@@ -152,19 +152,17 @@ class Backend:
) )
) )
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_end( self.metrics._request_success(workload=workload)
workload=workload,
req_response_time=time.time() - start_time,
reqnum=auth_data.reqnum,
)
return res return res
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored( self.metrics._request_errored(workload=workload)
workload=workload, reqnum=auth_data.reqnum
)
return web.Response(status=500) return web.Response(status=500)
finally: finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release() self.sem.release()
########### ###########
@@ -191,23 +189,26 @@ class Backend:
if health_check_url is None: if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck") log.debug("No healthcheck endpoint defined, skipping healthcheck")
return return
await sleep(5) while True:
try: await sleep(10)
log.debug(f"Performing healthcheck on {health_check_url}") if self.__start_healthcheck is False:
async with self.session.get(health_check_url) as response: continue
if response.status == 200: try:
log.debug("Healthcheck successful") log.debug(f"Performing healthcheck on {health_check_url}")
elif response.status == 503: async with self.session.get(health_check_url) as response:
log.debug(f"Healthcheck failed with status: {response.status}") if response.status == 200:
self.backend_errored( log.debug("Healthcheck successful")
f"Healthcheck failed with status: {response.status}" elif response.status == 503:
) log.debug(f"Healthcheck failed with status: {response.status}")
else: self.backend_errored(
# endpoint not ready yet so bail f"Healthcheck failed with status: {response.status}"
log.debug(f"Healthcheck Endpoint not ready: {response.status}") )
except Exception as e: else:
log.debug(f"Healthcheck failed with exception: {e}") # endpoint not ready yet so bail
self.backend_errored(str(e)) log.debug(f"Healthcheck Endpoint not ready: {response.status}")
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather( await gather(
@@ -225,6 +226,9 @@ class Backend:
return await self.session.post(url=handler.endpoint, json=api_payload) return await self.session.post(url=handler.endpoint, json=api_payload)
def __check_signature(self, auth_data: AuthData) -> bool: def __check_signature(self, auth_data: AuthData) -> bool:
if self.unsecured is True:
return True
def verify_signature(message, signature): def verify_signature(message, signature):
if self.pubkey is None: if self.pubkey is None:
log.debug(f"No Public Key!") log.debug(f"No Public Key!")
@@ -331,6 +335,7 @@ class Backend:
await sleep(5) await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True
self.metrics._model_loaded( self.metrics._model_loaded(
max_throughput=max_throughput, max_throughput=max_throughput,
) )
+7 -4
View File
@@ -8,7 +8,6 @@ from aiohttp import web, ClientResponse
import inspect import inspect
import psutil import psutil
import requests
""" """
@@ -206,13 +205,13 @@ class ModelMetrics:
workload_received: float workload_received: float
workload_cancelled: float workload_cancelled: float
workload_errored: float workload_errored: float
workload_pending: float
# these are not # these are not
cur_perf: float workload_pending: float
error_msg: Optional[str] error_msg: Optional[str]
max_throughput: float max_throughput: float
requests_recieved: Set[int] = field(default_factory=set) requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set) requests_working: Set[int] = field(default_factory=set)
last_update: float = field(default_factory=time.time)
@classmethod @classmethod
def empty(cls): def empty(cls):
@@ -221,12 +220,15 @@ class ModelMetrics:
workload_served=0.0, workload_served=0.0,
workload_cancelled=0.0, workload_cancelled=0.0,
workload_errored=0.0, workload_errored=0.0,
cur_perf=0.0,
workload_received=0.0, workload_received=0.0,
error_msg=None, error_msg=None,
max_throughput=0.0, max_throughput=0.0,
) )
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property @property
def workload_processing(self) -> float: def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0) return max(self.workload_received - self.workload_cancelled, 0.0)
@@ -240,6 +242,7 @@ class ModelMetrics:
self.workload_received = 0 self.workload_received = 0
self.workload_cancelled = 0 self.workload_cancelled = 0
self.workload_errored = 0 self.workload_errored = 0
self.last_update = time.time()
@dataclass @dataclass
+11 -14
View File
@@ -5,7 +5,6 @@ import json
from asyncio import sleep from asyncio import sleep
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from functools import cache from functools import cache
from urllib.parse import urljoin
import requests import requests
@@ -47,33 +46,31 @@ class Metrics:
self.model_metrics.requests_recieved.add(reqnum) self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum) self.model_metrics.requests_working.add(reqnum)
def _request_end( def _request_end(self, workload: float, reqnum: int) -> None:
self, workload: float, req_response_time: float, reqnum: int
) -> None:
""" """
this function is called after a response from model API is received. this function is called after handling of a request ends, regardless of the outcome
""" """
self.model_metrics.workload_served += workload
self.model_metrics.workload_pending -= workload self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum) self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.cur_perf = workload / req_response_time
def _request_success(self, workload: float) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += workload
self.update_pending = True self.update_pending = True
def _request_errored(self, workload: float, reqnum: int) -> None: def _request_errored(self, workload: float) -> None:
""" """
this function is called if model API returns an error this function is called if model API returns an error
""" """
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_errored += workload self.model_metrics.workload_errored += workload
self.model_metrics.requests_working.discard(reqnum)
def _request_canceled(self, workload: float, reqnum: int) -> None: def _request_canceled(self, workload: float) -> None:
""" """
this function is called if client drops connection before model API has responded this function is called if client drops connection before model API has responded
""" """
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_cancelled += workload self.model_metrics.workload_cancelled += workload
self.model_metrics.requests_working.discard(reqnum)
async def _send_metrics_loop(self) -> Awaitable[NoReturn]: async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True: while True:
@@ -119,7 +116,7 @@ class Metrics:
def send_data(report_addr: str) -> None: def send_data(report_addr: str) -> None:
data = compute_autoscaler_data() data = compute_autoscaler_data()
full_path = urljoin(report_addr, "/worker_status/") full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug( log.debug(
"\n".join( "\n".join(
[ [
+24 -9
View File
@@ -53,6 +53,13 @@ test_args.add_argument(
default="https://run.vast.ai", default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only", help="Call local autoscaler instead of prod, for dev use only",
) )
test_args.add_argument(
"-i",
dest="instance",
type=str,
default="prod",
help="Autoscaler shard to run the command against, default: prod",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]] GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
@@ -70,6 +77,7 @@ class ClientState:
api_key: str api_key: str
server_url: str server_url: str
worker_endpoint: str worker_endpoint: str
instance: str
payload: ApiPayload payload: ApiPayload
url: str = "" url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint status: ClientStatus = ClientStatus.FetchEndpoint
@@ -79,11 +87,7 @@ class ClientState:
def make_call(self): def make_call(self):
self.status = ClientStatus.FetchEndpoint self.status = ClientStatus.FetchEndpoint
endpoint_api_key = Endpoint.get_endpoint_api_key( if not self.api_key:
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
self.as_error.append( self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key", f"Endpoint {self.endpoint_group_name} not found for API key",
) )
@@ -91,12 +95,14 @@ class ClientState:
return return
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": endpoint_api_key, "api_key": self.api_key,
"cost": self.payload.count_workload(), "cost": self.payload.count_workload(),
} }
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post( response = requests.post(
urljoin(self.server_url, "/route/"), urljoin(self.server_url, "/route/"),
json=route_payload, json=route_payload,
headers=headers,
timeout=4, timeout=4,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -135,6 +141,7 @@ class ClientState:
try: try:
self.make_call() self.make_call()
except Exception as e: except Exception as e:
print(e)
self.status = ClientStatus.Error self.status = ClientStatus.Error
_ = e _ = e
self.conn_errors[self.url] += 1 self.conn_errors[self.url] += 1
@@ -226,6 +233,7 @@ def run_test(
server_url: str, server_url: str,
worker_endpoint: str, worker_endpoint: str,
payload_cls: Type[ApiPayload], payload_cls: Type[ApiPayload],
instance: str,
): ):
threads = [] threads = []
@@ -234,8 +242,7 @@ def run_test(
print_thread.daemon = True # makes threads get killed on program exit print_thread.daemon = True # makes threads get killed on program exit
print_thread.start() print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=endpoint_group_name, endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
account_api_key=api_key,
) )
if not endpoint_api_key: if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key") log.debug(f"Endpoint {endpoint_group_name} not found for API key")
@@ -248,6 +255,7 @@ def run_test(
server_url=server_url, server_url=server_url,
worker_endpoint=worker_endpoint, worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(), payload=payload_cls.for_test(),
instance=instance,
) )
clients.append(client) clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=()) thread = threading.Thread(target=client.simulate_user, args=())
@@ -281,12 +289,19 @@ def test_load_cmd(
args = arg_parser.parse_args() args = arg_parser.parse_args()
if hasattr(args, "comfy_model"): if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
run_test( run_test(
num_requests=args.num_requests, num_requests=args.num_requests,
requests_per_second=args.requests_per_second, requests_per_second=args.requests_per_second,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url, server_url=server_url,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint, worker_endpoint=endpoint,
payload_cls=payload_cls, payload_cls=payload_cls,
instance=args.instance,
) )
+3 -2
View File
@@ -1,4 +1,4 @@
aiohttp~=3.11 aiohttp[speedups]==3.10.1
anyio~=4.4 anyio~=4.4
lib~=4.0 lib~=4.0
nltk~=3.9 nltk~=3.9
@@ -6,4 +6,5 @@ psutil~=6.0
pycryptodome~=3.20 pycryptodome~=3.20
Requests~=2.32 Requests~=2.32
transformers~=4.52 transformers~=4.52
utils~=1.0 utils==1.0.*
hf_transfer>=0.1.9
+16 -14
View File
@@ -46,17 +46,19 @@ env | grep _ >> /etc/environment;
if [ ! -d "$ENV_PATH" ] if [ ! -d "$ENV_PATH" ]
then then
apt install -y python3.10-venv
echo "setting up venv" echo "setting up venv"
curl -LsSf https://astral.sh/uv/install.sh | sh
source ~/.local/bin/env
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR" git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
python3 -m venv "$WORKSPACE_DIR/worker-env" uv venv --managed-python "$WORKSPACE_DIR/worker-env" -p 3.10
source "$WORKSPACE_DIR/worker-env/bin/activate" source "$WORKSPACE_DIR/worker-env/bin/activate"
pip install -r vast-pyworker/requirements.txt uv pip install -r vast-pyworker/requirements.txt
touch ~/.no_auto_tmux touch ~/.no_auto_tmux
else else
source ~/.local/bin/env
source "$WORKSPACE_DIR/worker-env/bin/activate" source "$WORKSPACE_DIR/worker-env/bin/activate"
echo "environment activated" echo "environment activated"
echo "venv: $VIRTUAL_ENV" echo "venv: $VIRTUAL_ENV"
@@ -87,23 +89,23 @@ if [ "$USE_SSL" = true ]; then
IP.1 = 0.0.0.0 IP.1 = 0.0.0.0
EOF EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \ -nodes \
-sha256 \ -sha256 \
-keyout /etc/instance.key \ -keyout /etc/instance.key \
-out /etc/instance.csr \ -out /etc/instance.csr \
-config /etc/openssl-san.cnf -config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \ curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \ --data-binary @//etc/instance.csr \
-X \ -X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi fi
export REPORT_ADDR WORKER_PORT USE_SSL export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR" cd "$SERVER_DIR"
+6 -2
View File
@@ -17,7 +17,9 @@ class Endpoint:
""" """
@staticmethod @staticmethod
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]: def get_endpoint_api_key(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[str]:
""" """
Fetch endpoint API key from VastAI console following the healthcheck pattern. Fetch endpoint API key from VastAI console following the healthcheck pattern.
@@ -33,7 +35,9 @@ class Endpoint:
try: try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get(vast_console_url, headers=headers) response = requests.get(
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
)
if response.status_code != 200: if response.status_code != 200:
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
+1
View File
@@ -153,6 +153,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try:
+80
View File
@@ -0,0 +1,80 @@
# OpenAI Compatible PyWorker
This is the base PyWorker for OpenAI compatible inference servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
## Instance Setup
1. Pick a template
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
Several examples have been provided in the client to help you get started with your own implementation.
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
```
+77
View File
@@ -0,0 +1,77 @@
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
## Configuration
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
| Variable | Default Value | Used For |
| --- | --- | --- |
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
## Usage
We have provided a demonstration client to help you implement this template into your own infrastructure
### Client Setup
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
```
View File
+594
View File
@@ -0,0 +1,594 @@
import logging
import sys
import json
import subprocess
from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List
import requests
from utils.endpoint_util import Endpoint
from .data_types.client import CompletionConfig, ChatCompletionConfig
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient:
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(url, json=req_data, stream=stream)
elif method.upper() == "GET":
response = requests.get(url, params=req_data, stream=stream)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
class ToolManager:
"""Handles tool definitions and execution"""
@staticmethod
def list_files() -> str:
"""Execute ls on current directory"""
try:
result = subprocess.run(
["ls", "-la", "."], capture_output=True, text=True, timeout=10
)
if result.returncode == 0:
return result.stdout
else:
return f"Error: {result.stderr}"
except Exception as e:
return f"Error running ls: {e}"
@staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition"""
return [
{
"type": "function",
"function": {
"name": "list_files",
"description": "List files and directories in the cwd",
"parameters": {"type": "object", "properties": {}, "required": []},
},
}
]
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result"""
function_name = tool_call["function"]["name"]
if function_name == "list_files":
return self.list_files()
else:
raise ValueError(f"Unknown tool function: {function_name}")
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client
self.model = model
self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
full_response = ""
reasoning_content = ""
reasoning_started = False
content_started = False
for chunk in response_stream:
# Normalize the chunk
if isinstance(chunk, str):
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
# Parse delta from the chunk
choices = parsed_chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
print("\n🧠 Reasoning: ", end="", flush=True)
reasoning_started = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
reasoning_content += reasoning_token
# Print content token
if content_token:
if not content_started:
if show_reasoning and reasoning_started:
print(f"\n💬 Response: ", end="", flush=True)
else:
print("Assistant: ", end="", flush=True)
content_started = True
print(content_token, end="", flush=True)
full_response += content_token
print() # Ensure newline after response
if show_reasoning:
if reasoning_started or content_started:
print("\nStreaming completed.")
if reasoning_started:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if content_started:
print(f"Response tokens: {len(full_response.split())}")
return full_response
def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...")
# Try a simple request with minimal tools to test support
messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [
{
"type": "function",
"function": {"name": "test_function", "description": "Test function"},
}
]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none", # Don't actually call the tool
)
try:
response = self.client.call_chat_completions(config)
return True
except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}")
return False
def demo_completions(self) -> None:
"""Demo: test basic completions endpoint"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60)
print("TOOL USE DEMO: List Directory Contents")
print("=" * 60)
# Test if tools are supported first
if not self.test_tool_support():
return
# Request with tool available
messages = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
)
log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config)
if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use")
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
print(f"Assistant response: {message.get('content', 'No content')}")
# Check for tool calls
tool_calls = message.get("tool_calls")
if not tool_calls:
raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}")
# Execute the tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}")
tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}")
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result,
}
)
# Get final response
final_config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
)
print("Getting final response...")
final_response = self.client.call_chat_completions(final_config)
if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {})
final_content = final_message.get("content", "")
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print(final_content)
print("=" * 60)
def interactive_chat(self) -> None:
"""Interactive chat session with streaming"""
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history")
print()
messages = []
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif user_input.lower() == "clear":
messages = []
print("Chat history cleared")
continue
elif not user_input:
continue
messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(
response, show_reasoning=True
)
# Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content})
except KeyboardInterrupt:
print("\n👋 Chat interrupted. Goodbye!")
break
except Exception as e:
log.error(f"\nError: {e}")
continue
def main():
"""Main function with CLI switches for different tests"""
from lib.test_utils import test_args
# Add mandatory model argument
test_args.add_argument(
"--model", required=True, help="Model to use for requests (required)"
)
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
args = test_args.parse_args()
# Check that only one test mode is selected
test_modes = [
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session")
print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1)
elif selected_count > 1:
print("Please specify exactly one test mode")
sys.exit(1)
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=args.server_url,
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
# Run the selected test
if args.completion:
demo.demo_completions()
elif args.chat:
demo.demo_chat(use_streaming=False)
elif args.chat_stream:
demo.demo_chat(use_streaming=True)
elif args.tools:
demo.demo_ls_tool()
elif args.interactive:
demo.interactive_chat()
except Exception as e:
log.error(f"Error during test: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()
+58
View File
@@ -0,0 +1,58 @@
import json
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Optional, List, Dict, Any
class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj):
return {
field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return [self._serialize_recursive(item) for item in obj]
elif isinstance(obj, set):
return [self._serialize_recursive(item) for item in obj]
else:
return obj
def to_dict(self) -> Dict[str, Any]:
return self._serialize_recursive(self)
def to_json(self, indent: int = 2) -> str:
return json.dumps(self.to_dict(), indent=indent)
@dataclass
class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests"""
model: str
prompt: str = "Hello"
max_tokens: int = 256
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
@dataclass
class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests"""
model: str
messages: list = field(default_factory=list)
max_tokens: int = 2096
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
tool_choice: str = "auto"
def __post_init__(self):
if self.messages is None:
self.messages = [{"role": "user", "content": "Hello"}]
+182
View File
@@ -0,0 +1,182 @@
import os, json, random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse
import nltk
import logging
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__)
"""
Generic dataclass accepts any dictionary in input.
"""
@dataclass
class GenericData(ApiPayload, ABC):
input: Dict[str, Any]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls(input=data["input"])
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
errors = {}
# Validate required parameters
required_params = ["input"]
for param in required_params:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
try:
# Create clean data dict and delegate to from_dict
clean_data = {"input": json_msg["input"]}
return cls.from_dict(clean_data)
except (json.JSONDecodeError, JsonDataException) as e:
errors["parameters"] = str(e)
raise JsonDataException(errors)
@classmethod
@abstractmethod
def for_test(cls) -> "GenericData":
pass
def generate_payload_json(self) -> Dict[str, Any]:
return self.input
def count_workload(self) -> int:
return self.input.get("max_tokens", 0)
@dataclass
class GenericHandler(EndpointHandler[GenericData], ABC):
@property
@abstractmethod
def endpoint(self) -> str:
pass
@property
def healthcheck_endpoint(self) -> Optional[str]:
return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod
def payload_cls(cls) -> Type[GenericData]:
return GenericData
@abstractmethod
def make_benchmark_payload(self) -> GenericData:
pass
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=200,
content_type=model_response.content_type,
)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclass
class CompletionsData(GenericData):
@classmethod
def for_test(cls) -> "CompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
test_input = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class CompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/completions"
@classmethod
def payload_cls(cls) -> Type[CompletionsData]:
return CompletionsData
def make_benchmark_payload(self) -> CompletionsData:
return CompletionsData.for_test()
@dataclass
class ChatCompletionsData(GenericData):
"""Chat completions-specific data implementation"""
@classmethod
def for_test(cls) -> "ChatCompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
# Chat completions use messages format instead of prompt
test_input = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class ChatCompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/chat/completions"
@classmethod
def payload_cls(cls) -> Type[ChatCompletionsData]:
return ChatCompletionsData
def make_benchmark_payload(self) -> ChatCompletionsData:
return ChatCompletionsData.for_test()
+60
View File
@@ -0,0 +1,60 @@
import os
import logging
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
from aiohttp import web
from lib.backend import Backend, LogAction
from lib.server import start_server
# This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI
"Error: ShardCannotStart", # TGI
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
+28
View File
@@ -0,0 +1,28 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types.server import CompletionsData
import os
WORKER_ENDPOINT = "/v1/completions"
if __name__ == "__main__":
# Check if MODEL_NAME environment variable is set
model_name_set = os.environ.get("MODEL_NAME") is not None
# Add model argument - required only if MODEL_NAME is not set
test_args.add_argument(
"--model",
dest="model",
required=not model_name_set,
help="Model to use for completions request (required if MODEL_NAME env var not set)",
)
# Parse known args to get model early, before test_load_cmd adds its args
known_args, _ = test_args.parse_known_args()
# Set environment variable if model was provided
if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
# Now call test_load_cmd normally - it will add its own args and re-parse
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
+1
View File
@@ -100,6 +100,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try: