Files
pyworker/workers/tgi/client.py
T

62 lines
1.6 KiB
Python
Raw Normal View History

from vastai import Serverless
import asyncio
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
MAX_TOKENS = 1024
PROMPT = "Think step by step: Tell me about the Python programming language."
async def call_generate(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"inputs": PROMPT,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"return_full_text": False
}
2024-09-04 11:19:30 -07:00
}
2025-06-02 17:13:25 -07:00
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
print(resp["response"]["generated_text"])
2025-06-02 17:13:25 -07:00
2024-09-04 11:19:30 -07:00
async def call_generate_stream(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
2024-09-04 11:19:30 -07:00
payload = {
"inputs": PROMPT,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"do_sample": True,
"return_full_text": False,
}
2024-09-04 11:19:30 -07:00
}
resp = await endpoint.request(
"/generate_stream",
payload,
cost=MAX_TOKENS,
stream=True,
)
stream = resp["response"]
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main():
async with Serverless() as client:
await call_generate(client)
await call_generate_stream(client)
2024-09-04 11:19:30 -07:00
if __name__ == "__main__":
asyncio.run(main())