Files
pyworker/workers/openai/client.py
T

538 lines
20 KiB
Python
Raw Normal View History

2025-07-16 09:46:26 +01:00
import logging
import json
import os
import sys
2025-07-16 09:46:26 +01:00
import subprocess
import argparse
from typing import Any, Dict, List, Optional
from vastai import Serverless
import asyncio
2025-07-16 09:46:26 +01:00
# ---------------------- Logging ----------------------

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s[%(levelname)-5s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Name the logger after the module (the logging-module convention) rather than
# the source file path, so it fits the dotted logger hierarchy and can be
# configured with logging.getLogger("<module name>"). The format above does not
# print %(name)s, so console output is unchanged.
log = logging.getLogger(__name__)
# ---------------------- Prompts ----------------------
# Fixed sample inputs used by the --completion, --chat/--chat-stream and --tools demos.
COMPLETIONS_PROMPT = (
    "Zebras are primarily grazers and can subsist on lower-quality vegetation. "
    "They are preyed on mainly by"
)
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = (
    "Can you list the files in the current working directory "
    "and tell me what you see? What do you think this directory might be for?"
)

# ---------------------- Defaults ----------------------
ENDPOINT_NAME = "my-vllm-endpoint"  # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B"  # must support tool calling (needed by --tools)
MAX_TOKENS = 1024  # default generation cap; request helpers also pass it as `cost`
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------

class ToolManager:
    """Handles tool definitions and execution for the function-calling demo.

    A single tool is exposed, ``list_files``, which runs ``ls -la .`` in the
    current working directory.
    """

    @staticmethod
    def list_files() -> str:
        """Run ``ls -la .`` and return its stdout.

        Failures are reported in-band as a string beginning with ``"Error"``
        instead of being raised, because the result is sent back to the model
        verbatim as a tool message.
        """
        try:
            result = subprocess.run(
                ["ls", "-la", "."],
                capture_output=True,
                text=True,
                # Filenames are arbitrary bytes; strict decoding would raise
                # UnicodeDecodeError on one odd name and discard the whole listing.
                errors="replace",
                timeout=10,
            )
        except Exception as e:  # best-effort: `ls` missing, timeout, OS errors
            return f"Error running ls: {e}"

        if result.returncode == 0:
            return result.stdout
        return f"Error: {result.stderr}"

    @staticmethod
    def get_ls_tool_definition() -> List[Dict[str, Any]]:
        """Return the OpenAI-compatible ``tools`` schema for ``list_files`` (no parameters)."""
        return [
            {
                "type": "function",
                "function": {
                    "name": "list_files",
                    "description": "List files and directories in the cwd",
                    "parameters": {"type": "object", "properties": {}, "required": []},
                },
            }
        ]

    def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
        """Dispatch an OpenAI-style tool call to the matching tool.

        Args:
            tool_call: dict whose ``"function"`` entry holds the tool ``"name"``.

        Returns:
            The tool's string output.

        Raises:
            ValueError: if the function name is not a known tool.
        """
        function_name = (tool_call.get("function") or {}).get("name")
        if function_name == "list_files":
            return self.list_files()
        raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
    """POST a non-streaming ``/v1/completions`` request to a Vast endpoint.

    Args:
        client: an open ``Serverless`` client.
        model: model name placed in the request payload.
        prompt: raw text prompt.
        endpoint_name: name of the Vast endpoint to send the request to.
        **kwargs: optional ``max_tokens`` (default ``MAX_TOKENS``),
            ``temperature`` (default ``DEFAULT_TEMPERATURE``) and ``stop``.
            Any other keys are ignored.

    Returns:
        The ``"response"`` entry of the endpoint reply (presumably the
        OpenAI-style completion body).
    """
    endpoint = await client.get_endpoint(name=endpoint_name)

    payload: Dict[str, Any] = {
        "model": model,
        "prompt": prompt,
        "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
        "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
    }
    # Forward `stop` exactly like stream_completions does, so the streaming and
    # non-streaming variants accept the same options instead of this one
    # silently dropping it.
    if "stop" in kwargs:
        payload["stop"] = kwargs["stop"]
    log.debug("POST /v1/completions %s", json.dumps(payload)[:500])

    # max_tokens doubles as the request's cost estimate.
    resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
    return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
    """POST a non-streaming ``/v1/chat/completions`` request to a Vast endpoint.

    ``max_tokens`` and ``temperature`` fall back to ``MAX_TOKENS`` and
    ``DEFAULT_TEMPERATURE``. ``tools`` and ``tool_choice`` are sent only when
    the caller supplies them. Returns the ``"response"`` entry of the endpoint
    reply.
    """
    endpoint = await client.get_endpoint(name=endpoint_name)

    payload: Dict[str, Any] = {"model": model, "messages": messages}
    payload["max_tokens"] = kwargs.get("max_tokens", MAX_TOKENS)
    payload["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
    for optional_key in ("tools", "tool_choice"):
        if optional_key in kwargs:
            payload[optional_key] = kwargs[optional_key]

    log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
    reply = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"])
    return reply["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
    """Start a streaming ``/v1/completions`` request.

    Same defaults as the non-streaming call. An optional ``stop`` is forwarded.
    Returns the ``"response"`` entry of the endpoint reply, which callers
    iterate with ``async for``.
    """
    endpoint = await client.get_endpoint(name=endpoint_name)

    payload = dict(
        model=model,
        prompt=prompt,
        max_tokens=kwargs.get("max_tokens", MAX_TOKENS),
        temperature=kwargs.get("temperature", DEFAULT_TEMPERATURE),
        stream=True,
    )
    if "stop" in kwargs:
        payload["stop"] = kwargs["stop"]

    log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
    reply = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
    return reply["response"]  # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
    """Start a streaming ``/v1/chat/completions`` request.

    Same defaults as the non-streaming call. ``tools`` and ``tool_choice`` are
    forwarded when supplied. Returns the ``"response"`` entry of the endpoint
    reply, which callers iterate with ``async for`` as chunk dicts.
    """
    endpoint = await client.get_endpoint(name=endpoint_name)

    payload = dict(
        model=model,
        messages=messages,
        max_tokens=kwargs.get("max_tokens", MAX_TOKENS),
        temperature=kwargs.get("temperature", DEFAULT_TEMPERATURE),
        stream=True,
    )
    for optional_key in ("tools", "tool_choice"):
        if optional_key in kwargs:
            payload[optional_key] = kwargs[optional_key]

    log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
    reply = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"], stream=True)
    return reply["response"]  # async generator
# ---------------------- Demo Runner ----------------------

class APIDemo:
    """Demo and testing functionality for the API client.

    Each ``demo_*`` coroutine exercises one mode against the configured Vast
    endpoint and prints the results to stdout.
    """

    def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
        """
        Args:
            client: an open ``Serverless`` client used for every request.
            model: model name sent with each request.
            endpoint_name: Vast endpoint the requests are sent to.
            tool_manager: tool implementation for the function-calling demo;
                a default ``ToolManager`` is created when omitted.
        """
        self.client = client
        self.model = model
        self.endpoint_name = endpoint_name
        self.tool_manager = tool_manager or ToolManager()

    # ----- Stream helpers -----
    @staticmethod
    def _first_choice(chunk: Dict[str, Any]) -> Dict[str, Any]:
        """Return the first choice of a streamed chunk, or ``{}`` if there is none."""
        return (chunk.get("choices") or [{}])[0] or {}

    @staticmethod
    def _reasoning_text(delta: Dict[str, Any]) -> Optional[str]:
        """Return the reasoning fragment carried by a stream delta, if any.

        Accepts both ``reasoning_content`` and ``reasoning``, mirroring the
        non-streaming path in ``demo_chat``.
        """
        return delta.get("reasoning_content") or delta.get("reasoning")

    async def _print_labeled_stream(
        self,
        stream,
        reasoning_label: str,
        answer_label: str,
        tool_calls_state: Optional[Dict[int, Dict[str, Any]]] = None,
    ) -> str:
        """Echo a chat stream's reasoning and content under the given one-time headers.

        When ``tool_calls_state`` is given, streamed ``tool_calls`` fragments
        are merged into it. Returns the concatenated content text.
        """
        content_parts: List[str] = []
        printed_reasoning = False
        printed_answer = False
        async for chunk in stream:
            delta = self._first_choice(chunk).get("delta") or {}

            rc = self._reasoning_text(delta)
            if rc:
                if not printed_reasoning:
                    printed_reasoning = True
                    print(reasoning_label, end="", flush=True)
                print(rc, end="", flush=True)

            content_part = delta.get("content")
            if content_part:
                content_parts.append(content_part)
                if not printed_answer:
                    printed_answer = True
                    print(answer_label, end="", flush=True)
                print(content_part, end="", flush=True)

            if tool_calls_state is not None and delta.get("tool_calls"):
                for tc_delta in delta["tool_calls"]:
                    _merge_tool_call_delta(tool_calls_state, tc_delta)
        return "".join(content_parts)

    # ----- Streaming handler -----
    async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
        """Print a streamed chat completion as it arrives and return the answer text.

        Args:
            stream: async iterable of OpenAI-style chat completion chunk dicts.
            show_reasoning: when true, reasoning fragments are printed under a
                "🧠 Reasoning" header and summary stats are printed at the end.
                When false, reasoning fragments are skipped.

        Returns:
            The concatenated ``content`` of the stream (reasoning excluded).
        """
        response_parts: List[str] = []
        reasoning_parts: List[str] = []
        printed_reasoning = False
        printed_answer = False
        finish_reason = None

        async for chunk in stream:
            choice = self._first_choice(chunk)
            delta = choice.get("delta") or {}
            # Track finish reason
            if choice.get("finish_reason"):
                finish_reason = choice.get("finish_reason")

            # reasoning tokens
            rc = self._reasoning_text(delta)
            if rc and show_reasoning:
                if not printed_reasoning:
                    print("\n🧠 Reasoning: ", end="", flush=True)
                    printed_reasoning = True
                print(rc, end="", flush=True)
                reasoning_parts.append(rc)

            # content tokens
            content_part = delta.get("content")
            if content_part:
                if not printed_answer:
                    if show_reasoning and printed_reasoning:
                        print("\n💬 Response: ", end="", flush=True)
                    else:
                        print("Assistant: ", end="", flush=True)
                    printed_answer = True
                print(content_part, end="", flush=True)
                response_parts.append(content_part)

        print()  # newline
        full_response = "".join(response_parts)
        reasoning_content = "".join(reasoning_parts)

        if show_reasoning and (printed_reasoning or printed_answer):
            print("\nStreaming completed.")
            # "tokens" here are whitespace-separated words, a rough proxy only.
            if printed_reasoning:
                print(f"Reasoning tokens: {len(reasoning_content.split())}")
            if printed_answer:
                print(f"Response tokens: {len(full_response.split())}")
            if finish_reason:
                print(f"Finish reason: {finish_reason}")

        return full_response

    async def demo_completions(self) -> None:
        """Send one non-streaming completions request and print the raw JSON reply."""
        print("=" * 60)
        print("COMPLETIONS DEMO")
        print("=" * 60)
        response = await call_completions(
            client=self.client,
            model=self.model,
            prompt=COMPLETIONS_PROMPT,
            endpoint_name=self.endpoint_name,
            max_tokens=MAX_TOKENS,
            temperature=DEFAULT_TEMPERATURE,
        )
        print("\nResponse:")
        print(json.dumps(response, indent=2))

    async def demo_chat(self, use_streaming: bool = True) -> None:
        """Ask ``CHAT_PROMPT`` via chat completions, streaming or not, and print the answer."""
        print("=" * 60)
        print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
        print("=" * 60)
        messages = [{"role": "user", "content": CHAT_PROMPT}]

        if use_streaming:
            stream = await stream_chat_completions(
                client=self.client,
                model=self.model,
                messages=messages,
                endpoint_name=self.endpoint_name,
                max_tokens=MAX_TOKENS,
                temperature=DEFAULT_TEMPERATURE,
            )
            try:
                await self.handle_streaming_response(stream, show_reasoning=True)
            except Exception as e:
                log.error("\nError during streaming: %s", e, exc_info=True)
        else:
            response = await call_chat_completions(
                client=self.client,
                model=self.model,
                messages=messages,
                endpoint_name=self.endpoint_name,
                max_tokens=MAX_TOKENS,
                temperature=DEFAULT_TEMPERATURE,
            )
            choice = (response.get("choices") or [{}])[0]
            message = choice.get("message", {})
            content = message.get("content", "")
            reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
            if reasoning:
                print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
            print(f"\n💬 Assistant: {content}")
            print(f"\nFull Response:\n{json.dumps(response, indent=2)}")

    async def test_tool_support(self) -> bool:
        """Probe that tool schema is accepted (no actual call).

        Returns True if a tiny request carrying a tool definition (with
        ``tool_choice="none"``) succeeds, False if it raises.
        """
        messages = [{"role": "user", "content": "Hello"}]
        minimal_tool = [
            {
                "type": "function",
                "function": {"name": "test_function", "description": "Test function"},
            }
        ]
        try:
            _ = await call_chat_completions(
                client=self.client,
                model=self.model,
                messages=messages,
                endpoint_name=self.endpoint_name,
                tools=minimal_tool,
                tool_choice="none",
                max_tokens=10,
            )
            return True
        except Exception as e:
            log.error("Endpoint does not support tool calling: %s", e)
            return False

    async def demo_ls_tool(self) -> None:
        """Ask to list files using function calling, then provide final analysis.

        Flow:
        1. Probe tool support.
        2. Stream a first turn with the tool schema and collect any ``tool_calls``.
        3. Run each requested tool locally and append the results as ``tool`` messages.
        4. Stream a second turn for the final answer.
        """
        print("=" * 60)
        print("TOOL USE DEMO: List Directory Contents")
        print("=" * 60)
        if not await self.test_tool_support():
            return
        messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]

        # First pass: let the model decide tools, stream tool_calls and partial content
        stream = await stream_chat_completions(
            client=self.client,
            model=self.model,
            messages=messages,
            endpoint_name=self.endpoint_name,
            tools=self.tool_manager.get_ls_tool_definition(),
            tool_choice="auto",
            max_tokens=MAX_TOKENS,
            temperature=DEFAULT_TEMPERATURE,
        )
        tool_calls_state: Dict[int, Dict[str, Any]] = {}
        first_content = await self._print_labeled_stream(
            stream, "🧠 Reasoning: ", "\n💬 Response: ", tool_calls_state
        )

        # If no tool calls, we're done.
        if not tool_calls_state:
            print("\n(No tool calls were made.)")
            return

        tool_calls = _tool_state_to_message_tool_calls(tool_calls_state)
        # Tool replies reference their call by id. If the server streamed no id,
        # synthesize one so the assistant message and the tool replies still match.
        for position, tc in enumerate(tool_calls):
            if not tc.get("id"):
                tc["id"] = f"call_{position}"

        # Build assistant message with tool_calls
        messages.append({
            "role": "assistant",
            "content": first_content or None,
            "tool_calls": tool_calls,
        })

        # Execute tools and feed results back
        for tc in tool_calls:
            tool_name = (tc.get("function") or {}).get("name")
            call_id = tc["id"]
            raw_args = (tc.get("function") or {}).get("arguments") or "{}"
            try:
                # list_files takes no arguments; parse only to reject malformed JSON.
                if raw_args.strip():
                    json.loads(raw_args)
            except ValueError as e:
                tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
                messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
                continue
            try:
                # Dispatch through the injected ToolManager so a custom manager's
                # tools are honored. It raises ValueError for unknown names.
                tool_result = self.tool_manager.execute_tool_call(tc)
            except ValueError as e:
                tool_result = json.dumps({"error": str(e)})
            except Exception as e:
                tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
            print("\n[Tool executed]", tool_name)
            print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
            messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})

        # Second pass: get final streamed answer after tool results
        stream2 = await stream_chat_completions(
            client=self.client,
            model=self.model,
            messages=messages,
            endpoint_name=self.endpoint_name,
            max_tokens=MAX_TOKENS,
            temperature=DEFAULT_TEMPERATURE,
        )
        final_text = await self._print_labeled_stream(
            stream2, "\n🧠 Reasoning (post-tools): ", "\n💬 Response (final): "
        )
        print("\n" + "=" * 60)
        print("FINAL LLM ANALYSIS:")
        print("=" * 60)
        print(final_text)
        print("=" * 60)

    async def interactive_chat(self) -> None:
        """Interactive chat session with streaming.

        ``quit`` exits and ``clear`` resets the history. Ctrl-C or end of
        input (Ctrl-D / closed stdin) also exit.
        """
        print("=" * 60)
        print("INTERACTIVE STREAMING CHAT")
        print("=" * 60)
        print("Type 'quit' to exit, 'clear' to clear history")
        print()
        messages: List[Dict[str, Any]] = []

        while True:
            try:
                # NOTE: input() is synchronous and blocks the event loop while waiting.
                user_input = input("You: ").strip()
                command = user_input.lower()
                if command == "quit":
                    print("👋 Goodbye!")
                    break
                if command == "clear":
                    messages = []
                    print("Chat history cleared")
                    continue
                if not user_input:
                    continue

                messages.append({"role": "user", "content": user_input})
                try:
                    # handle_streaming_response prints its own "Assistant: " /
                    # "💬 Response: " header, so no prefix is printed here.
                    stream = await stream_chat_completions(
                        client=self.client,
                        model=self.model,
                        messages=messages,
                        endpoint_name=self.endpoint_name,
                        max_tokens=MAX_TOKENS,
                        temperature=DEFAULT_TEMPERATURE,
                    )
                    assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
                except Exception:
                    # Drop the unanswered user turn so the history keeps
                    # alternating user/assistant messages on the next attempt.
                    messages.pop()
                    raise

                # Add assistant response to conversation history
                messages.append({"role": "assistant", "content": assistant_content})

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the session. In the generic handler below,
                # input() would fail again on every iteration and loop forever.
                print("\n👋 Chat interrupted. Goodbye!")
                break
            except Exception as e:
                log.error("\nError: %s", e)
                continue
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: ``--model``/``--endpoint`` plus one mutually exclusive demo-mode flag."""
    parser = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
    parser.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
    parser.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")

    # (flag, help) pairs, registered in this order as store_true switches.
    mode_flags = (
        ("--completion", "Test completions endpoint"),
        ("--chat", "Test chat completions endpoint (non-streaming)"),
        ("--chat-stream", "Test chat completions endpoint with streaming"),
        ("--tools", "Test function calling with ls tool (non-streaming+streamed phases)"),
        ("--interactive", "Start interactive streaming chat session"),
    )
    mode_group = parser.add_mutually_exclusive_group(required=False)
    for flag, help_text in mode_flags:
        mode_group.add_argument(flag, action="store_true", help=help_text)
    return parser
async def main_async():
    """CLI entry point: parse arguments, require exactly one mode, and run that demo."""
    args = build_arg_parser().parse_args()
    mode_count = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])

    if mode_count == 0:
        print("Please specify exactly one test mode:")
        for usage_line in (
            " --completion : Test completions endpoint",
            " --chat : Test chat completions endpoint (non-streaming)",
            " --chat-stream : Test chat completions endpoint with streaming",
            " --tools : Test function calling with ls tool",
            " --interactive : Start interactive streaming chat session",
        ):
            print(usage_line)
        script_name = os.path.basename(sys.argv[0])
        print(f"\nExample: python {script_name} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
        sys.exit(1)
    if mode_count > 1:
        # Defensive: the parser's mutually exclusive group normally rejects this first.
        print("Please specify exactly one test mode")
        sys.exit(1)

    print("=" * 60)
    print(f"Using model: {args.model}")
    print(f"Using endpoint: {args.endpoint}")
    try:
        async with Serverless() as client:
            demo = APIDemo(client, args.model, args.endpoint, ToolManager())
            # (flag, zero-arg coroutine factory); exactly one flag is set here.
            runners = (
                (args.completion, demo.demo_completions),
                (args.chat, lambda: demo.demo_chat(use_streaming=False)),
                (args.chat_stream, lambda: demo.demo_chat(use_streaming=True)),
                (args.tools, demo.demo_ls_tool),
                (args.interactive, demo.interactive_chat),
            )
            for enabled, run in runners:
                if enabled:
                    await run()
                    break
    except Exception as e:
        log.error("Error during test: %s", e, exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main_async())