Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ac51947b4 |
+43
-45
@@ -126,7 +126,7 @@ class Backend:
|
|||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
await request.wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
self.metrics._request_canceled(workload=workload)
|
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
@@ -141,6 +141,7 @@ class Backend:
|
|||||||
else:
|
else:
|
||||||
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
||||||
try:
|
try:
|
||||||
|
start_time = time.time()
|
||||||
response = await self.__call_api(handler=handler, payload=payload)
|
response = await self.__call_api(handler=handler, payload=payload)
|
||||||
status_code = response.status
|
status_code = response.status
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -152,17 +153,19 @@ class Backend:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
res = await handler.generate_client_response(request, response)
|
res = await handler.generate_client_response(request, response)
|
||||||
self.metrics._request_success(workload=workload)
|
self.metrics._request_end(
|
||||||
|
workload=workload,
|
||||||
|
req_response_time=time.time() - start_time,
|
||||||
|
reqnum=auth_data.reqnum,
|
||||||
|
)
|
||||||
return res
|
return res
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
log.debug(f"[backend] Request error: {e}")
|
log.debug(f"[backend] Request error: {e}")
|
||||||
self.metrics._request_errored(workload=workload)
|
self.metrics._request_errored(
|
||||||
|
workload=workload, reqnum=auth_data.reqnum
|
||||||
|
)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
finally:
|
||||||
self.metrics._request_end(
|
|
||||||
workload=workload,
|
|
||||||
reqnum=auth_data.reqnum,
|
|
||||||
)
|
|
||||||
self.sem.release()
|
self.sem.release()
|
||||||
|
|
||||||
###########
|
###########
|
||||||
@@ -183,6 +186,12 @@ class Backend:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
finally:
|
||||||
|
if request.task.cancelled():
|
||||||
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
|
self.metrics._request_canceled(
|
||||||
|
workload=workload, reqnum=auth_data.reqnum
|
||||||
|
)
|
||||||
|
|
||||||
async def __healthcheck(self):
|
async def __healthcheck(self):
|
||||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||||
@@ -280,52 +289,41 @@ class Backend:
|
|||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log.debug("Initial run to trigger model loading...")
|
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
|
||||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
|
||||||
|
|
||||||
max_throughput = 0
|
max_throughput = 0
|
||||||
|
last_throughput = 0
|
||||||
sum_throughput = 0
|
sum_throughput = 0
|
||||||
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
for run in range(self.benchmark_handler.benchmark_runs + 1):
|
||||||
|
|
||||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
tasks = []
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
total_workload = 0
|
res = await self.__call_api(
|
||||||
|
handler=self.benchmark_handler, payload=payload
|
||||||
for _ in range(concurrent_requests):
|
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
|
||||||
total_workload += payload.count_workload()
|
|
||||||
tasks.append(
|
|
||||||
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
|
||||||
)
|
|
||||||
|
|
||||||
responses = await gather(*tasks)
|
|
||||||
time_elapsed = time.time() - start
|
|
||||||
|
|
||||||
throughput = total_workload / time_elapsed
|
|
||||||
sum_throughput += throughput
|
|
||||||
max_throughput = max(max_throughput, throughput)
|
|
||||||
|
|
||||||
# Log results for debugging
|
|
||||||
log.debug(
|
|
||||||
"\n".join(
|
|
||||||
[
|
|
||||||
"#" * 60,
|
|
||||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
|
||||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
|
||||||
f"Throughput: {throughput} workload/s",
|
|
||||||
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
|
||||||
"#" * 60,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
data = await res.json()
|
||||||
|
time_elapsed = time.time() - start
|
||||||
|
# first run triggers one-time loading of the model which is very slow, so we skip counting it
|
||||||
|
if run == 0:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
workload = payload.count_workload()
|
||||||
|
last_throughput = workload / time_elapsed
|
||||||
|
sum_throughput += last_throughput
|
||||||
|
max_throughput = max(max_throughput, last_throughput)
|
||||||
|
log.debug(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"#" * 60,
|
||||||
|
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
|
||||||
|
"",
|
||||||
|
f"response: {data}",
|
||||||
|
"#" * 60,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||||
log.debug(
|
log.debug(
|
||||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||||
)
|
)
|
||||||
|
# save max_throughput so we don't have to run benchmark again on restart of cold instances
|
||||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||||
f.write(str(max_throughput))
|
f.write(str(max_throughput))
|
||||||
return max_throughput
|
return max_throughput
|
||||||
|
|||||||
+4
-7
@@ -8,6 +8,7 @@ from aiohttp import web, ClientResponse
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -205,13 +206,13 @@ class ModelMetrics:
|
|||||||
workload_received: float
|
workload_received: float
|
||||||
workload_cancelled: float
|
workload_cancelled: float
|
||||||
workload_errored: float
|
workload_errored: float
|
||||||
# these are not
|
|
||||||
workload_pending: float
|
workload_pending: float
|
||||||
|
# these are not
|
||||||
|
cur_perf: float
|
||||||
error_msg: Optional[str]
|
error_msg: Optional[str]
|
||||||
max_throughput: float
|
max_throughput: float
|
||||||
requests_recieved: Set[int] = field(default_factory=set)
|
requests_recieved: Set[int] = field(default_factory=set)
|
||||||
requests_working: Set[int] = field(default_factory=set)
|
requests_working: Set[int] = field(default_factory=set)
|
||||||
last_update: float = field(default_factory=time.time)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls):
|
def empty(cls):
|
||||||
@@ -220,15 +221,12 @@ class ModelMetrics:
|
|||||||
workload_served=0.0,
|
workload_served=0.0,
|
||||||
workload_cancelled=0.0,
|
workload_cancelled=0.0,
|
||||||
workload_errored=0.0,
|
workload_errored=0.0,
|
||||||
|
cur_perf=0.0,
|
||||||
workload_received=0.0,
|
workload_received=0.0,
|
||||||
error_msg=None,
|
error_msg=None,
|
||||||
max_throughput=0.0,
|
max_throughput=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def cur_perf(self) -> float:
|
|
||||||
return max(self.workload_served / (time.time() - self.last_update), 0.0)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workload_processing(self) -> float:
|
def workload_processing(self) -> float:
|
||||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||||
@@ -242,7 +240,6 @@ class ModelMetrics:
|
|||||||
self.workload_received = 0
|
self.workload_received = 0
|
||||||
self.workload_cancelled = 0
|
self.workload_cancelled = 0
|
||||||
self.workload_errored = 0
|
self.workload_errored = 0
|
||||||
self.last_update = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
+13
-11
@@ -46,31 +46,33 @@ class Metrics:
|
|||||||
self.model_metrics.requests_recieved.add(reqnum)
|
self.model_metrics.requests_recieved.add(reqnum)
|
||||||
self.model_metrics.requests_working.add(reqnum)
|
self.model_metrics.requests_working.add(reqnum)
|
||||||
|
|
||||||
def _request_end(self, workload: float, reqnum: int) -> None:
|
def _request_end(
|
||||||
|
self, workload: float, req_response_time: float, reqnum: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called after handling of a request ends, regardless of the outcome
|
this function is called after a response from model API is received.
|
||||||
"""
|
|
||||||
self.model_metrics.workload_pending -= workload
|
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
|
||||||
|
|
||||||
def _request_success(self, workload: float) -> None:
|
|
||||||
"""
|
|
||||||
this function is called after a response from model API is received and forwarded.
|
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_served += workload
|
self.model_metrics.workload_served += workload
|
||||||
|
self.model_metrics.workload_pending -= workload
|
||||||
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
|
self.model_metrics.cur_perf = workload / req_response_time
|
||||||
self.update_pending = True
|
self.update_pending = True
|
||||||
|
|
||||||
def _request_errored(self, workload: float) -> None:
|
def _request_errored(self, workload: float, reqnum: int) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if model API returns an error
|
this function is called if model API returns an error
|
||||||
"""
|
"""
|
||||||
|
self.model_metrics.workload_pending -= workload
|
||||||
self.model_metrics.workload_errored += workload
|
self.model_metrics.workload_errored += workload
|
||||||
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
|
|
||||||
def _request_canceled(self, workload: float) -> None:
|
def _request_canceled(self, workload: float, reqnum: int) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if client drops connection before model API has responded
|
this function is called if client drops connection before model API has responded
|
||||||
"""
|
"""
|
||||||
|
self.model_metrics.workload_pending -= workload
|
||||||
self.model_metrics.workload_cancelled += workload
|
self.model_metrics.workload_cancelled += workload
|
||||||
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
|
|
||||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
+1
-1
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
|||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app)
|
runner = web.AppRunner(app, handler_cancellation=True)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from collections import Counter
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
from utils.ssl import get_cert_file_path
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from lib.data_types import AuthData, ApiPayload
|
from lib.data_types import AuthData, ApiPayload
|
||||||
@@ -121,11 +120,9 @@ class ClientState:
|
|||||||
self.url = worker_address
|
self.url = worker_address
|
||||||
url = urljoin(worker_address, self.worker_endpoint)
|
url = urljoin(worker_address, self.worker_endpoint)
|
||||||
self.status = ClientStatus.Generating
|
self.status = ClientStatus.Generating
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
verify=get_cert_file_path(),
|
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
self.infer_error.append(
|
self.infer_error.append(
|
||||||
|
|||||||
+2
-2
@@ -1,4 +1,4 @@
|
|||||||
aiohttp[speedups]==3.10.1
|
aiohttp==3.10.1
|
||||||
anyio~=4.4
|
anyio~=4.4
|
||||||
lib~=4.0
|
lib~=4.0
|
||||||
nltk~=3.9
|
nltk~=3.9
|
||||||
@@ -6,5 +6,5 @@ psutil~=6.0
|
|||||||
pycryptodome~=3.20
|
pycryptodome~=3.20
|
||||||
Requests~=2.32
|
Requests~=2.32
|
||||||
transformers~=4.52
|
transformers~=4.52
|
||||||
utils==1.0.*
|
utils~=1.0
|
||||||
hf_transfer>=0.1.9
|
hf_transfer>=0.1.9
|
||||||
|
|||||||
@@ -30,12 +30,7 @@ class Endpoint:
|
|||||||
Returns:
|
Returns:
|
||||||
Endpoint API key if successful, None otherwise
|
Endpoint API key if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
endpoints = {
|
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
|
||||||
"alpha": "alpha",
|
|
||||||
"candidate": "candidate",
|
|
||||||
"prod": "console",
|
|
||||||
}
|
|
||||||
vast_console_url = f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
|
||||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
import tempfile
|
|
||||||
from functools import cache
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def get_cert_file_path():
|
|
||||||
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
|
||||||
response = requests.get(cert_url)
|
|
||||||
response.raise_for_status()
|
|
||||||
# Use a temporary file that is not deleted on close
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
return f.name
|
|
||||||
@@ -5,7 +5,6 @@ import requests
|
|||||||
|
|
||||||
from lib.test_utils import print_truncate_res
|
from lib.test_utils import print_truncate_res
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
from utils.ssl import get_cert_file_path
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||||
@@ -52,7 +51,6 @@ def call_default_workflow(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
verify=get_cert_file_path(),
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
@@ -143,7 +141,6 @@ def call_custom_workflow_for_sd3(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
verify=get_cert_file_path(),
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
|
|||||||
+17
-26
@@ -6,7 +6,6 @@ from urllib.parse import urljoin
|
|||||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
from utils.ssl import get_cert_file_path
|
|
||||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -29,16 +28,24 @@ class APIClient:
|
|||||||
DEFAULT_TIMEOUT = 4
|
DEFAULT_TIMEOUT = 4
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str
|
||||||
endpoint_group_name: str,
|
|
||||||
api_key: str,
|
|
||||||
server_url: str,
|
|
||||||
endpoint_api_key: str,
|
|
||||||
):
|
):
|
||||||
self.endpoint_group_name = endpoint_group_name
|
self.endpoint_group_name = endpoint_group_name
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.endpoint_api_key = endpoint_api_key
|
self.instance = instance
|
||||||
|
self.endpoint_api_key = self._get_endpoint_api_key()
|
||||||
|
|
||||||
|
def _get_endpoint_api_key(self) -> Optional[str]:
|
||||||
|
"""Get the endpoint API key"""
|
||||||
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=self.endpoint_group_name,
|
||||||
|
account_api_key=self.api_key,
|
||||||
|
instance=self.instance,
|
||||||
|
)
|
||||||
|
if not endpoint_api_key:
|
||||||
|
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
|
||||||
|
return endpoint_api_key
|
||||||
|
|
||||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||||
"""Get worker URL and auth data from routing service"""
|
"""Get worker URL and auth data from routing service"""
|
||||||
@@ -91,13 +98,9 @@ class APIClient:
|
|||||||
|
|
||||||
# Make the request using the specified method
|
# Make the request using the specified method
|
||||||
if method.upper() == "POST":
|
if method.upper() == "POST":
|
||||||
response = requests.post(
|
response = requests.post(url, json=req_data, stream=stream)
|
||||||
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
|
||||||
)
|
|
||||||
elif method.upper() == "GET":
|
elif method.upper() == "GET":
|
||||||
response = requests.get(
|
response = requests.get(url, params=req_data, stream=stream)
|
||||||
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
|
|
||||||
@@ -551,24 +554,12 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
|
||||||
endpoint_name=args.endpoint_group_name,
|
|
||||||
account_api_key=args.api_key,
|
|
||||||
instance=args.instance,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not endpoint_api_key:
|
|
||||||
log.error(
|
|
||||||
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Create the core API client
|
# Create the core API client
|
||||||
client = APIClient(
|
client = APIClient(
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
api_key=args.api_key,
|
api_key=args.api_key,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
endpoint_api_key=endpoint_api_key,
|
instance=args.instance,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create tool manager and demo (passing the model parameter)
|
# Create tool manager and demo (passing the model parameter)
|
||||||
|
|||||||
@@ -124,12 +124,7 @@ class CompletionsData(GenericData):
|
|||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
test_input = {
|
test_input = {"model": model, "prompt": prompt, "temperature": 0.7}
|
||||||
"model": model,
|
|
||||||
"prompt": prompt,
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens": 500,
|
|
||||||
}
|
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
@@ -163,7 +158,6 @@ class ChatCompletionsData(GenericData):
|
|||||||
"model": model,
|
"model": model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 500,
|
|
||||||
}
|
}
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import json
|
|||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
from utils.ssl import get_cert_file_path
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
@@ -43,11 +42,7 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
|||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
print(f"url: {url}")
|
print(f"url: {url}")
|
||||||
response = requests.post(
|
response = requests.post(url, json=req_data)
|
||||||
url,
|
|
||||||
json=req_data,
|
|
||||||
verify=get_cert_file_path(),
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
res = response.json()
|
res = response.json()
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
Reference in New Issue
Block a user