diff --git a/workers/openai/README.md b/workers/openai/README.md index 2436784..0dbaaa4 100644 --- a/workers/openai/README.md +++ b/workers/openai/README.md @@ -34,38 +34,20 @@ uv pip install -r requirements.txt Several examples have been provided in the client to help you get started with your own implementation. -### Completions - -Call to `/v1/completions` with json response +First, set your API key as an environment variable: ```bash -python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME> +export VAST_API_KEY=<YOUR_API_KEY> ``` -### Chat Completion (json) - -Call to `/v1/chat/completions` with json response - -```bash -python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME> -``` +The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively. ### Chat Completion (streaming) Call to `/v1/chat/completions` with streaming response ```bash -python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME> -``` - -### Tool Use (json) - -Call to `/v1/chat/completions` with tool and json response. - -This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. - -```bash -python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME> +python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> ``` ### Interactive Chat (streaming) @@ -75,6 +57,32 @@ Interactive session with calls to `/v1/chat/completions`. Type `clear` to clear the chat history or `quit` to exit. ```bash -python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME> +python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> +``` + +### Chat Completion (json) + +Call to `/v1/chat/completions` with json response + +```bash +python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> +``` + +### Tool Use (json) + +Call to `/v1/chat/completions` with tool and json response. + +This test defines a simple tool which will list the contents of the local pyworker directory. 
The output is then analysed by the model. + +```bash +python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> +``` + +### Completions + +Call to `/v1/completions` with json response + +```bash +python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> ``` diff --git a/workers/openai/client.py b/workers/openai/client.py index 8c88444..a92ad95 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -18,7 +18,7 @@ logging.basicConfig( log = logging.getLogger(__file__) # ---------------------- Prompts ---------------------- -COMPLETIONS_PROMPT = "the capital of USA is" +COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by" CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." TOOLS_PROMPT = ( "Can you list the files in the current working directory and tell me what you see? " @@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[ # ---- OpenAI-compatible calls (non-streaming) ---- -async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: +async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) return resp["response"] -async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: +async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint 
= await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis return resp["response"] # ---- Streaming variants ---- -async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): +async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs): - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) return resp["response"] # async generator -async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): +async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs): - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L class APIDemo: """Demo and testing functionality for the API client""" - def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): + def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None): self.client = client self.model = model + self.endpoint_name = endpoint_name self.tool_manager = tool_manager or ToolManager() # ----- Streaming handler ----- @@ -185,10 +186,15 @@ class APIDemo: reasoning_content = "" printed_reasoning = False printed_answer = False + finish_reason = None async for chunk in stream: choice = (chunk.get("choices") or [{}])[0] delta = choice.get("delta", {}) + + # Track 
finish reason + if choice.get("finish_reason"): + finish_reason = choice.get("finish_reason") # reasoning tokens rc = delta.get("reasoning_content") @@ -219,6 +225,8 @@ class APIDemo: print(f"Reasoning tokens: {len(reasoning_content.split())}") if printed_answer: print(f"Response tokens: {len(full_response.split())}") + if finish_reason: + print(f"Finish reason: {finish_reason}") return full_response @@ -231,6 +239,7 @@ class APIDemo: client=self.client, model=self.model, prompt=COMPLETIONS_PROMPT, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ) @@ -249,6 +258,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE ) @@ -261,6 +271,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE ) @@ -287,6 +298,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, tools=minimal_tool, tool_choice="none", max_tokens=10 @@ -312,6 +324,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, tools=self.tool_manager.get_ls_tool_definition(), tool_choice="auto", max_tokens=MAX_TOKENS, @@ -389,6 +402,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ) @@ -427,7 +441,6 @@ class APIDemo: print("=" * 60) print("INTERACTIVE STREAMING CHAT") print("=" * 60) - print(f"Using model: {self.model}") print("Type 'quit' to exit, 'clear' to clear history") print() @@ -453,7 +466,8 @@ class APIDemo: stream = await stream_chat_completions( client=self.client, model=self.model, - messages=messages, + messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=0.7 ) @@ 
-473,8 +487,8 @@ class APIDemo: # ---------------------- CLI ---------------------- def build_arg_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") - p.add_argument("--model", required=True, help="Model to use for requests (required)") - p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") + p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") modes = p.add_mutually_exclusive_group(required=False) modes.add_argument("--completion", action="store_true", help="Test completions endpoint") @@ -502,12 +516,14 @@ async def main_async(): print("Please specify exactly one test mode") sys.exit(1) - print(f"Using model: {args.model}") print("=" * 60) + print(f"Using model: {args.model}") + print(f"Using endpoint: {args.endpoint}") + try: async with Serverless() as client: - demo = APIDemo(client, args.model, ToolManager()) + demo = APIDemo(client, args.model, args.endpoint, ToolManager()) if args.completion: await demo.demo_completions()