From 6b5b1341a79387a0bb953eaf304335ab50a8c0bf Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 18:38:42 -0800 Subject: [PATCH] update tgi client --- workers/openai/README.md | 5 +- workers/tgi/README.md | 102 +++++++++++++++-- workers/tgi/client.py | 229 +++++++++++++++++++++++++++++++++------ 3 files changed, 290 insertions(+), 46 deletions(-) diff --git a/workers/openai/README.md b/workers/openai/README.md index 0dbaaa4..f7596f3 100644 --- a/workers/openai/README.md +++ b/workers/openai/README.md @@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. -- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended) +- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended) - [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) -- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless)) All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. -2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. 
## Client Setup (Demo) diff --git a/workers/tgi/README.md b/workers/tgi/README.md index 5cf8488..9147e38 100644 --- a/workers/tgi/README.md +++ b/workers/tgi/README.md @@ -1,19 +1,103 @@ -This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints: +# HuggingFace TGI PyWorker -1. `generate`: Generates the LLM's response to a given prompt in a single request. -2. `generate_stream`: Streams the LLM's response token by token. +This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. -Both endpoints use the following API payload format: +## Instance Setup + +1. Pick a template + +This worker is compatible with any TGI backend. We have a template you can use or you can create your own. + +- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless)) + +The template can be configured via the template interface. You may want to change the model or startup arguments. + +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. + +## Client Setup (Demo) + +1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. + +```bash +git clone https://github.com/vast-ai/pyworker +cd pyworker +pip install uv +uv venv -p 3.12 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +## Using the Test Client + +The test client demonstrates both streaming and non-streaming generation using TGI's native API. + +First, set your API key as an environment variable: + +```bash +export VAST_API_KEY=<your-api-key> +``` + +The `--endpoint` flag is optional. If not provided, it defaults to the `ENDPOINT_NAME` constant defined in `workers/tgi/client.py`.
+ +### Generate (Streaming) + +Call to `/generate_stream` with streaming response: + +```bash +python -m workers.tgi.client --generate-stream --endpoint <your-endpoint-name> +``` + +### Generate (Non-Streaming) + +Call to `/generate` with JSON response: + +```bash +python -m workers.tgi.client --generate --endpoint <your-endpoint-name> +``` + +### Interactive Session (Streaming) + +Interactive session with streaming responses. Type `quit` to exit. + +```bash +python -m workers.tgi.client --interactive --endpoint <your-endpoint-name> +``` + +## API Endpoints + +TGI provides two primary endpoints: + +### Generate (Non-Streaming) + +`/generate` - Returns the complete response in a single request. ```json { - "inputs": "PROMPT", + "inputs": "Your prompt here", "parameters": { - "max_new_tokens": 250 + "max_new_tokens": 1024, + "temperature": 0.7, + "return_full_text": false } } ``` -Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an -instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take -approximately 2 seconds to complete. +### Generate Stream (Streaming) + +`/generate_stream` - Streams the response token by token. + +```json +{ + "inputs": "Your prompt here", + "parameters": { + "max_new_tokens": 1024, + "temperature": 0.7, + "do_sample": true, + "return_full_text": false + } +} +``` + +## Performance Notes + +The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
diff --git a/workers/tgi/client.py b/workers/tgi/client.py index f307602..23b40c2 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -1,61 +1,222 @@ +import logging +import json +import os +import sys +import argparse + from vastai import Serverless import asyncio -ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name -MAX_TOKENS = 1024 -PROMPT = "Think step by step: Tell me about the Python programming language." +# ---------------------- Logging ---------------------- +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) -async def call_generate(client: Serverless) -> None: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) +# ---------------------- Defaults ---------------------- +DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language." + +ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name +MAX_TOKENS = 1024 +DEFAULT_TEMPERATURE = 0.7 + + +# ---------------------- API Calls ---------------------- +async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict: + """Non-streaming generation via /generate endpoint""" + endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "inputs": PROMPT, + "inputs": prompt, "parameters": { - "max_new_tokens": MAX_TOKENS, - "temperature": 0.7, - "return_full_text": False + "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "return_full_text": False, } } - - resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) - - print(resp["response"]["generated_text"]) + log.debug("POST /generate %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"]) + return resp["response"] -async def call_generate_stream(client: Serverless) -> 
None: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) +async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs): + """Streaming generation via /generate_stream endpoint""" + endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "inputs": PROMPT, + "inputs": prompt, "parameters": { - "max_new_tokens": MAX_TOKENS, - "temperature": 0.7, + "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "do_sample": True, "return_full_text": False, } } - + log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500]) resp = await endpoint.request( "/generate_stream", payload, - cost=MAX_TOKENS, + cost=payload["parameters"]["max_new_tokens"], stream=True, ) - stream = resp["response"] + return resp["response"] # async generator - printed_answer = False - async for event in stream: - tok = (event.get("token") or {}).get("text") - if tok: - if not printed_answer: - printed_answer = True - print("Answer:\n", end="", flush=True) - print(tok, end="", flush=True) -async def main(): - async with Serverless() as client: - await call_generate(client) - await call_generate_stream(client) +# ---------------------- Demo Runner ---------------------- +class APIDemo: + """Demo and testing functionality for the TGI API client""" + + def __init__(self, client: Serverless, endpoint_name: str): + self.client = client + self.endpoint_name = endpoint_name + + async def handle_streaming_response(self, stream) -> str: + """Process streaming response and print tokens""" + full_response = "" + printed_answer = False + + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + if not printed_answer: + printed_answer = True + print("\n💬 Response: ", end="", flush=True) + print(tok, end="", flush=True) + full_response += tok + + print() # newline + if printed_answer: + print(f"\nStreaming completed.
Response tokens: {len(full_response.split())}") + + return full_response + + async def demo_generate(self) -> None: + """Demo non-streaming generation""" + print("=" * 60) + print("GENERATE DEMO (NON-STREAMING)") + print("=" * 60) + + response = await call_generate( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=DEFAULT_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + print(f"\n💬 Response: {response.get('generated_text', '')}") + print(f"\nFull Response:\n{json.dumps(response, indent=2)}") + + async def demo_generate_stream(self) -> None: + """Demo streaming generation""" + print("=" * 60) + print("GENERATE DEMO (STREAMING)") + print("=" * 60) + + stream = await call_generate_stream( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=DEFAULT_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + try: + await self.handle_streaming_response(stream) + except Exception as e: + log.error("\nError during streaming: %s", e, exc_info=True) + + async def interactive_chat(self) -> None: + """Interactive session with streaming generation""" + print("=" * 60) + print("INTERACTIVE STREAMING SESSION") + print("=" * 60) + print(f"Using endpoint: {self.endpoint_name}") + print("Type 'quit' to exit") + print() + + while True: + try: + user_input = input("You: ").strip() + + if user_input.lower() == "quit": + print("👋 Goodbye!") + break + elif not user_input: + continue + + print("Assistant: ", end="", flush=True) + stream = await call_generate_stream( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=user_input, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + full_response = "" + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + print(tok, end="", flush=True) + full_response += tok + print() # newline + + except KeyboardInterrupt: + print("\n👋 Session interrupted.
Goodbye!") + break + except Exception as e: + log.error("\nError: %s", e) + continue + + +# ---------------------- CLI ---------------------- +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") + + modes = p.add_mutually_exclusive_group(required=False) + modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)") + modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming") + modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session") + return p + + +async def main_async(): + args = build_arg_parser().parse_args() + + selected = sum([args.generate, args.generate_stream, args.interactive]) + if selected == 0: + print("Please specify exactly one test mode:") + print(" --generate : Test generate endpoint (non-streaming)") + print(" --generate-stream : Test generate endpoint with streaming") + print(" --interactive : Start interactive streaming session") + print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint") + sys.exit(1) + elif selected > 1: + print("Please specify exactly one test mode") + sys.exit(1) + + print("=" * 60) + print(f"Using endpoint: {args.endpoint}") + + try: + async with Serverless() as client: + demo = APIDemo(client, args.endpoint) + + if args.generate: + await demo.demo_generate() + elif args.generate_stream: + await demo.demo_generate_stream() + elif args.interactive: + await demo.interactive_chat() + + except Exception as e: + log.error("Error during test: %s", e, exc_info=True) + sys.exit(1) + if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main_async())