defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present

This commit is contained in:
Colter Downing
2025-12-03 16:57:28 -08:00
parent 0339b471c5
commit adedb8ba90
2 changed files with 63 additions and 39 deletions
+31 -23
View File
@@ -34,38 +34,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation. Several examples have been provided in the client to help you get started with your own implementation.
### Completions First, set your API key as an environment variable:
Call to `/v1/completions` with json response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME> export VAST_API_KEY=<your_api_key>
``` ```
### Chat Completion (json) The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming) ### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response Call to `/v1/chat/completions` with streaming response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME> python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
``` ```
### Interactive Chat (streaming) ### Interactive Chat (streaming)
@@ -75,6 +57,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit. Type `clear` to clear the chat history or `quit` to exit.
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME> python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
``` ```
+32 -16
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ---------------------- # ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is" COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = ( TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? " "Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ---- # ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"] return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"] return resp["response"]
# ---- Streaming variants ---- # ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client self.client = client
self.model = model self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler ----- # ----- Streaming handler -----
@@ -185,10 +186,15 @@ class APIDemo:
reasoning_content = "" reasoning_content = ""
printed_reasoning = False printed_reasoning = False
printed_answer = False printed_answer = False
finish_reason = None
async for chunk in stream: async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0] choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {}) delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens # reasoning tokens
rc = delta.get("reasoning_content") rc = delta.get("reasoning_content")
@@ -219,6 +225,8 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer: if printed_answer:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response return full_response
@@ -231,6 +239,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
prompt=COMPLETIONS_PROMPT, prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -249,6 +258,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -261,6 +271,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -287,6 +298,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none", tool_choice="none",
max_tokens=10 max_tokens=10
@@ -312,6 +324,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
@@ -389,6 +402,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -427,7 +441,6 @@ class APIDemo:
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
print("=" * 60) print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
@@ -453,7 +466,8 @@ class APIDemo:
stream = await stream_chat_completions( stream = await stream_chat_completions(
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=0.7 temperature=0.7
) )
@@ -473,8 +487,8 @@ class APIDemo:
# ---------------------- CLI ---------------------- # ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)") p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False) modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint") modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -502,12 +516,14 @@ async def main_async():
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
print(f"Using model: {args.model}")
print("=" * 60) print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try: try:
async with Serverless() as client: async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager()) demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if args.completion: if args.completion:
await demo.demo_completions() await demo.demo_completions()