Compare commits
44 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a7617162a7 | |||
| d8f51a2edc | |||
| ee57ed207b | |||
| c98d661513 | |||
| f6fd1c6ac1 | |||
| 055e346c8c | |||
| 1cedb28acf | |||
| 0397af719d | |||
| 4fdc314fd9 | |||
| 639d82f5b4 | |||
| 25db78e39d | |||
| 4e2f2311d0 | |||
| 38782d89bc | |||
| 0185216ccb | |||
| b20d9e714c | |||
| b1eb65d75d | |||
| 1d09d7fe96 | |||
| 1b37054dec | |||
| 1a1e4174b8 | |||
| b8377c4081 | |||
| 1e4fa87437 | |||
| 4c5fa03c7b | |||
| a8fe74f771 | |||
| b482de8394 | |||
| 703435d10e | |||
| 947fc5eea4 | |||
| 7c1a544b19 | |||
| 16b414676e | |||
| ba74ac8136 | |||
| 92ff412679 | |||
| fc75a64684 | |||
| b00bef547c | |||
| 3f4acb29fa | |||
| 58b078f908 | |||
| f9fdf04884 | |||
| 636f17d27f | |||
| 08c88f7527 | |||
| 8797b504af | |||
| cd946b0a9f | |||
| c595b42410 | |||
| 0bf3247a34 | |||
| 52ac4c0c1a | |||
| 8804e17201 | |||
| 4016cf9a53 |
+66
-47
@@ -11,7 +11,7 @@ from functools import cached_property
|
|||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from Crypto.Signature import pkcs1_15
|
from Crypto.Signature import pkcs1_15
|
||||||
@@ -75,7 +75,13 @@ class Backend:
|
|||||||
@cached_property
|
@cached_property
|
||||||
def session(self):
|
def session(self):
|
||||||
log.debug(f"starting session with {self.model_server_url}")
|
log.debug(f"starting session with {self.model_server_url}")
|
||||||
return ClientSession(self.model_server_url)
|
connector = TCPConnector(
|
||||||
|
force_close=True, # Required for long running jobs
|
||||||
|
enable_cleanup_closed=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout = ClientTimeout(total=None)
|
||||||
|
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||||
|
|
||||||
def create_handler(
|
def create_handler(
|
||||||
self,
|
self,
|
||||||
@@ -126,7 +132,7 @@ class Backend:
|
|||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
await request.wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
self.metrics._request_canceled(workload=workload)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
@@ -141,7 +147,6 @@ class Backend:
|
|||||||
else:
|
else:
|
||||||
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
|
||||||
response = await self.__call_api(handler=handler, payload=payload)
|
response = await self.__call_api(handler=handler, payload=payload)
|
||||||
status_code = response.status
|
status_code = response.status
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -153,19 +158,17 @@ class Backend:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
res = await handler.generate_client_response(request, response)
|
res = await handler.generate_client_response(request, response)
|
||||||
self.metrics._request_end(
|
self.metrics._request_success(workload=workload)
|
||||||
workload=workload,
|
|
||||||
req_response_time=time.time() - start_time,
|
|
||||||
reqnum=auth_data.reqnum,
|
|
||||||
)
|
|
||||||
return res
|
return res
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
log.debug(f"[backend] Request error: {e}")
|
log.debug(f"[backend] Request error: {e}")
|
||||||
self.metrics._request_errored(
|
self.metrics._request_errored(workload=workload)
|
||||||
workload=workload, reqnum=auth_data.reqnum
|
|
||||||
)
|
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
finally:
|
||||||
|
self.metrics._request_end(
|
||||||
|
workload=workload,
|
||||||
|
reqnum=auth_data.reqnum,
|
||||||
|
)
|
||||||
self.sem.release()
|
self.sem.release()
|
||||||
|
|
||||||
###########
|
###########
|
||||||
@@ -186,25 +189,31 @@ class Backend:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
|
||||||
if request.task.cancelled():
|
@cached_property
|
||||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
def healthcheck_session(self):
|
||||||
self.metrics._request_canceled(
|
"""Dedicated session for healthchecks to avoid conflicts with API session"""
|
||||||
workload=workload, reqnum=auth_data.reqnum
|
log.debug("creating dedicated healthcheck session")
|
||||||
)
|
connector = TCPConnector(
|
||||||
|
force_close=True, # Keep this for isolation
|
||||||
|
enable_cleanup_closed=True,
|
||||||
|
)
|
||||||
|
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
|
||||||
|
return ClientSession(timeout=timeout, connector=connector)
|
||||||
|
|
||||||
async def __healthcheck(self):
|
async def __healthcheck(self):
|
||||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||||
if health_check_url is None:
|
if health_check_url is None:
|
||||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
return
|
return
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
await sleep(10)
|
await sleep(10)
|
||||||
if self.__start_healthcheck is False:
|
if self.__start_healthcheck is False:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
async with self.session.get(health_check_url) as response:
|
async with self.healthcheck_session.get(health_check_url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
log.debug("Healthcheck successful")
|
log.debug("Healthcheck successful")
|
||||||
elif response.status == 503:
|
elif response.status == 503:
|
||||||
@@ -213,7 +222,6 @@ class Backend:
|
|||||||
f"Healthcheck failed with status: {response.status}"
|
f"Healthcheck failed with status: {response.status}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# endpoint not ready yet so bail
|
|
||||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Healthcheck failed with exception: {e}")
|
log.debug(f"Healthcheck failed with exception: {e}")
|
||||||
@@ -289,41 +297,52 @@ class Backend:
|
|||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
log.debug("Initial run to trigger model loading...")
|
||||||
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
|
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
|
|
||||||
max_throughput = 0
|
max_throughput = 0
|
||||||
last_throughput = 0
|
|
||||||
sum_throughput = 0
|
sum_throughput = 0
|
||||||
for run in range(self.benchmark_handler.benchmark_runs + 1):
|
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
||||||
|
|
||||||
|
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
tasks = []
|
||||||
res = await self.__call_api(
|
total_workload = 0
|
||||||
handler=self.benchmark_handler, payload=payload
|
|
||||||
)
|
for _ in range(concurrent_requests):
|
||||||
data = await res.json()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
time_elapsed = time.time() - start
|
total_workload += payload.count_workload()
|
||||||
# first run triggers one-time loading of the model which is very slow, so we skip counting it
|
tasks.append(
|
||||||
if run == 0:
|
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
continue
|
|
||||||
else:
|
|
||||||
workload = payload.count_workload()
|
|
||||||
last_throughput = workload / time_elapsed
|
|
||||||
sum_throughput += last_throughput
|
|
||||||
max_throughput = max(max_throughput, last_throughput)
|
|
||||||
log.debug(
|
|
||||||
"\n".join(
|
|
||||||
[
|
|
||||||
"#" * 60,
|
|
||||||
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
|
|
||||||
"",
|
|
||||||
f"response: {data}",
|
|
||||||
"#" * 60,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
responses = await gather(*tasks)
|
||||||
|
time_elapsed = time.time() - start
|
||||||
|
|
||||||
|
throughput = total_workload / time_elapsed
|
||||||
|
sum_throughput += throughput
|
||||||
|
max_throughput = max(max_throughput, throughput)
|
||||||
|
|
||||||
|
# Log results for debugging
|
||||||
|
log.debug(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"#" * 60,
|
||||||
|
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||||
|
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||||
|
f"Throughput: {throughput} workload/s",
|
||||||
|
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
||||||
|
"#" * 60,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||||
log.debug(
|
log.debug(
|
||||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||||
)
|
)
|
||||||
# save max_throughput so we don't have to run benchmark again on restart of cold instances
|
|
||||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||||
f.write(str(max_throughput))
|
f.write(str(max_throughput))
|
||||||
return max_throughput
|
return max_throughput
|
||||||
|
|||||||
+7
-4
@@ -8,7 +8,6 @@ from aiohttp import web, ClientResponse
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -206,13 +205,13 @@ class ModelMetrics:
|
|||||||
workload_received: float
|
workload_received: float
|
||||||
workload_cancelled: float
|
workload_cancelled: float
|
||||||
workload_errored: float
|
workload_errored: float
|
||||||
workload_pending: float
|
|
||||||
# these are not
|
# these are not
|
||||||
cur_perf: float
|
workload_pending: float
|
||||||
error_msg: Optional[str]
|
error_msg: Optional[str]
|
||||||
max_throughput: float
|
max_throughput: float
|
||||||
requests_recieved: Set[int] = field(default_factory=set)
|
requests_recieved: Set[int] = field(default_factory=set)
|
||||||
requests_working: Set[int] = field(default_factory=set)
|
requests_working: Set[int] = field(default_factory=set)
|
||||||
|
last_update: float = field(default_factory=time.time)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls):
|
def empty(cls):
|
||||||
@@ -221,12 +220,15 @@ class ModelMetrics:
|
|||||||
workload_served=0.0,
|
workload_served=0.0,
|
||||||
workload_cancelled=0.0,
|
workload_cancelled=0.0,
|
||||||
workload_errored=0.0,
|
workload_errored=0.0,
|
||||||
cur_perf=0.0,
|
|
||||||
workload_received=0.0,
|
workload_received=0.0,
|
||||||
error_msg=None,
|
error_msg=None,
|
||||||
max_throughput=0.0,
|
max_throughput=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_perf(self) -> float:
|
||||||
|
return max(self.workload_served / (time.time() - self.last_update), 0.0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workload_processing(self) -> float:
|
def workload_processing(self) -> float:
|
||||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||||
@@ -240,6 +242,7 @@ class ModelMetrics:
|
|||||||
self.workload_received = 0
|
self.workload_received = 0
|
||||||
self.workload_cancelled = 0
|
self.workload_cancelled = 0
|
||||||
self.workload_errored = 0
|
self.workload_errored = 0
|
||||||
|
self.last_update = time.time()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
+29
-25
@@ -45,34 +45,33 @@ class Metrics:
|
|||||||
self.model_metrics.workload_received += workload
|
self.model_metrics.workload_received += workload
|
||||||
self.model_metrics.requests_recieved.add(reqnum)
|
self.model_metrics.requests_recieved.add(reqnum)
|
||||||
self.model_metrics.requests_working.add(reqnum)
|
self.model_metrics.requests_working.add(reqnum)
|
||||||
|
|
||||||
def _request_end(
|
|
||||||
self, workload: float, req_response_time: float, reqnum: int
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
this function is called after a response from model API is received.
|
|
||||||
"""
|
|
||||||
self.model_metrics.workload_served += workload
|
|
||||||
self.model_metrics.workload_pending -= workload
|
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
|
||||||
self.model_metrics.cur_perf = workload / req_response_time
|
|
||||||
self.update_pending = True
|
self.update_pending = True
|
||||||
|
|
||||||
def _request_errored(self, workload: float, reqnum: int) -> None:
|
def _request_end(self, workload: float, reqnum: int) -> None:
|
||||||
|
"""
|
||||||
|
this function is called after handling of a request ends, regardless of the outcome
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_pending -= workload
|
||||||
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
|
|
||||||
|
def _request_success(self, workload: float) -> None:
|
||||||
|
"""
|
||||||
|
this function is called after a response from model API is received and forwarded.
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_served += workload
|
||||||
|
self.update_pending = True
|
||||||
|
|
||||||
|
def _request_errored(self, workload: float) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if model API returns an error
|
this function is called if model API returns an error
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_pending -= workload
|
|
||||||
self.model_metrics.workload_errored += workload
|
self.model_metrics.workload_errored += workload
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
|
||||||
|
|
||||||
def _request_canceled(self, workload: float, reqnum: int) -> None:
|
def _request_canceled(self, workload: float) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if client drops connection before model API has responded
|
this function is called if client drops connection before model API has responded
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_pending -= workload
|
|
||||||
self.model_metrics.workload_cancelled += workload
|
self.model_metrics.workload_cancelled += workload
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
|
||||||
|
|
||||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
while True:
|
while True:
|
||||||
@@ -80,10 +79,10 @@ class Metrics:
|
|||||||
elapsed = time.time() - self.last_metric_update
|
elapsed = time.time() - self.last_metric_update
|
||||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||||
self.__send_metrics_and_reset(elapsed)
|
self.__send_metrics_and_reset()
|
||||||
elif self.update_pending or elapsed > 10:
|
elif self.update_pending or elapsed > 10:
|
||||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||||
self.__send_metrics_and_reset(elapsed)
|
self.__send_metrics_and_reset()
|
||||||
|
|
||||||
def _model_loaded(self, max_throughput: float) -> None:
|
def _model_loaded(self, max_throughput: float) -> None:
|
||||||
self.system_metrics.model_loading_time = (
|
self.system_metrics.model_loading_time = (
|
||||||
@@ -98,13 +97,13 @@ class Metrics:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
def __send_metrics_and_reset(self, elapsed):
|
def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalaerData:
|
def compute_autoscaler_data() -> AutoScalaerData:
|
||||||
return AutoScalaerData(
|
return AutoScalaerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
||||||
cur_load=(self.model_metrics.workload_processing / elapsed),
|
cur_load=(self.model_metrics.workload_processing),
|
||||||
max_perf=self.model_metrics.max_throughput,
|
max_perf=self.model_metrics.max_throughput,
|
||||||
cur_perf=self.model_metrics.cur_perf,
|
cur_perf=self.model_metrics.cur_perf,
|
||||||
error_msg=self.model_metrics.error_msg or "",
|
error_msg=self.model_metrics.error_msg or "",
|
||||||
@@ -116,7 +115,7 @@ class Metrics:
|
|||||||
url=self.url,
|
url=self.url,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_data(report_addr: str) -> None:
|
def send_data(report_addr: str) -> bool:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -131,21 +130,26 @@ class Metrics:
|
|||||||
)
|
)
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
requests.post(full_path, json=asdict(data), timeout=1)
|
res = requests.post(full_path, json=asdict(data), timeout=1)
|
||||||
break
|
res.raise_for_status()
|
||||||
|
return True
|
||||||
except requests.Timeout:
|
except requests.Timeout:
|
||||||
log.debug(f"autoscaler status update timed out")
|
log.debug(f"autoscaler status update timed out")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"autoscaler status update failed with error: {e}")
|
log.debug(f"autoscaler status update failed with error: {e}")
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||||
|
log.debug(f"failed to send update through {report_addr}")
|
||||||
|
return False
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
send_data(report_addr)
|
success = send_data(report_addr)
|
||||||
|
if success is True:
|
||||||
|
break
|
||||||
self.update_pending = False
|
self.update_pending = False
|
||||||
self.model_metrics.reset()
|
self.model_metrics.reset()
|
||||||
self.system_metrics.reset()
|
self.system_metrics.reset()
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from collections import Counter
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from lib.data_types import AuthData, ApiPayload
|
from lib.data_types import AuthData, ApiPayload
|
||||||
@@ -120,9 +121,11 @@ class ClientState:
|
|||||||
self.url = worker_address
|
self.url = worker_address
|
||||||
url = urljoin(worker_address, self.worker_endpoint)
|
url = urljoin(worker_address, self.worker_endpoint)
|
||||||
self.status = ClientStatus.Generating
|
self.status = ClientStatus.Generating
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
self.infer_error.append(
|
self.infer_error.append(
|
||||||
|
|||||||
+2
-2
@@ -1,4 +1,4 @@
|
|||||||
aiohttp==3.10.1
|
aiohttp[speedups]==3.10.1
|
||||||
anyio~=4.4
|
anyio~=4.4
|
||||||
lib~=4.0
|
lib~=4.0
|
||||||
nltk~=3.9
|
nltk~=3.9
|
||||||
@@ -6,5 +6,5 @@ psutil~=6.0
|
|||||||
pycryptodome~=3.20
|
pycryptodome~=3.20
|
||||||
Requests~=2.32
|
Requests~=2.32
|
||||||
transformers~=4.52
|
transformers~=4.52
|
||||||
utils~=1.0
|
utils==1.0.*
|
||||||
hf_transfer>=0.1.9
|
hf_transfer>=0.1.9
|
||||||
|
|||||||
+54
-24
@@ -3,13 +3,12 @@
|
|||||||
set -e -o pipefail
|
set -e -o pipefail
|
||||||
|
|
||||||
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
||||||
|
SERVER_DIR="$WORKSPACE_DIR/worker"
|
||||||
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
|
||||||
ENV_PATH="$WORKSPACE_DIR/worker-env"
|
ENV_PATH="$WORKSPACE_DIR/worker-env"
|
||||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
|
||||||
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
|
||||||
USE_SSL="${USE_SSL:-true}"
|
USE_SSL="${USE_SSL:-true}"
|
||||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||||
mkdir -p "$WORKSPACE_DIR"
|
mkdir -p "$WORKSPACE_DIR"
|
||||||
@@ -22,49 +21,82 @@ function echo_var(){
|
|||||||
echo "$1: ${!1}"
|
echo "$1: ${!1}"
|
||||||
}
|
}
|
||||||
|
|
||||||
[ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1
|
# Updated validation - BACKEND no longer required, but MODEL_LOG still is
|
||||||
[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1
|
[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1
|
||||||
[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1
|
[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1
|
||||||
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && echo "For comfyui backends, COMFY_MODEL must be set!" && exit 1
|
|
||||||
|
|
||||||
|
echo "start_server.sh - SDK Worker Version"
|
||||||
echo "start_server.sh"
|
|
||||||
date
|
date
|
||||||
|
|
||||||
echo_var BACKEND
|
|
||||||
echo_var REPORT_ADDR
|
echo_var REPORT_ADDR
|
||||||
echo_var WORKER_PORT
|
echo_var WORKER_PORT
|
||||||
echo_var WORKSPACE_DIR
|
echo_var WORKSPACE_DIR
|
||||||
echo_var SERVER_DIR
|
|
||||||
echo_var ENV_PATH
|
echo_var ENV_PATH
|
||||||
echo_var DEBUG_LOG
|
echo_var DEBUG_LOG
|
||||||
echo_var PYWORKER_LOG
|
echo_var PYWORKER_LOG
|
||||||
echo_var MODEL_LOG
|
echo_var MODEL_LOG
|
||||||
|
echo_var MODEL_SERVER_URL
|
||||||
|
echo_var PYWORKER_REPO
|
||||||
|
echo_var PYWORKER_REF
|
||||||
|
|
||||||
env | grep _ >> /etc/environment;
|
# Populate /etc/environment with quoted values
|
||||||
|
if ! grep -q "VAST" /etc/environment; then
|
||||||
|
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||||
|
name=${line%%=*}
|
||||||
|
value=${line#*=}
|
||||||
|
printf '%s="%s"\n' "$name" "$value"
|
||||||
|
done > /etc/environment
|
||||||
|
fi
|
||||||
|
|
||||||
if [ ! -d "$ENV_PATH" ]
|
if [ ! -d "$ENV_PATH" ]
|
||||||
then
|
then
|
||||||
echo "setting up venv"
|
echo "setting up venv"
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
if ! which uv; then
|
||||||
source ~/.local/bin/env
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
source ~/.local/bin/env
|
||||||
|
fi
|
||||||
|
|
||||||
uv venv --managed-python "$WORKSPACE_DIR/worker-env" -p 3.10
|
if [[ ! -d $SERVER_DIR ]]; then
|
||||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
echo "Cloning worker repository..."
|
||||||
|
git clone --depth=1 "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
||||||
|
fi
|
||||||
|
|
||||||
uv pip install -r vast-pyworker/requirements.txt
|
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||||
|
echo "Checking out ref: $PYWORKER_REF"
|
||||||
|
(
|
||||||
|
cd "$SERVER_DIR"
|
||||||
|
git fetch --depth=1 origin "$PYWORKER_REF"
|
||||||
|
git checkout "$PYWORKER_REF"
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
|
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
||||||
|
source "$ENV_PATH/bin/activate"
|
||||||
|
|
||||||
|
# Install vast-sdk from server-side-sdk branch
|
||||||
|
echo "Installing vast-sdk from GitHub (server-side-sdk branch)..."
|
||||||
|
uv pip install "git+https://github.com/vast-ai/vast-sdk.git@server-side-sdk"
|
||||||
|
|
||||||
|
# Install requirements from worker repo if they exist
|
||||||
|
if [ -f "${SERVER_DIR}/requirements.txt" ]; then
|
||||||
|
echo "Installing additional dependencies from requirements.txt..."
|
||||||
|
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
||||||
|
fi
|
||||||
|
|
||||||
touch ~/.no_auto_tmux
|
touch ~/.no_auto_tmux
|
||||||
else
|
else
|
||||||
source ~/.local/bin/env
|
[[ -f ~/.local/bin/env ]] && source ~/.local/bin/env
|
||||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||||
echo "environment activated"
|
echo "environment activated"
|
||||||
echo "venv: $VIRTUAL_ENV"
|
echo "venv: $VIRTUAL_ENV"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
|
# Check that worker.py exists
|
||||||
|
if [ ! -f "$SERVER_DIR/worker.py" ]; then
|
||||||
|
echo "ERROR: worker.py not found in $SERVER_DIR"
|
||||||
|
echo "Please ensure your PYWORKER_REPO contains a worker.py file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
if [ "$USE_SSL" = true ]; then
|
if [ "$USE_SSL" = true ]; then
|
||||||
|
|
||||||
@@ -102,9 +134,6 @@ EOF
|
|||||||
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
||||||
|
|
||||||
cd "$SERVER_DIR"
|
cd "$SERVER_DIR"
|
||||||
@@ -115,5 +144,6 @@ echo "launching PyWorker server"
|
|||||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||||
|
|
||||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
# Launch the SDK-based worker instead of the old backend system
|
||||||
echo "launching PyWorker server done"
|
(python3 worker.py |& tee -a "$PYWORKER_LOG") &
|
||||||
|
echo "launching PyWorker server done"
|
||||||
+20
-2
@@ -16,6 +16,24 @@ class Endpoint:
|
|||||||
Utility class for handling endpoint operations.
|
Utility class for handling endpoint operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_autoscaler_server_url(instance: str) -> str:
|
||||||
|
endpoints = {
|
||||||
|
"alpha": "run-alpha",
|
||||||
|
"candidate": "run-candidate",
|
||||||
|
"prod": "run",
|
||||||
|
}
|
||||||
|
return f"https://{endpoints[instance]}.vast.ai/"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_server_url(instance: str) -> str:
|
||||||
|
endpoints = {
|
||||||
|
"alpha": "alpha",
|
||||||
|
"candidate": "candidate",
|
||||||
|
"prod": "console",
|
||||||
|
}
|
||||||
|
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(
|
def get_endpoint_api_key(
|
||||||
endpoint_name: str, account_api_key: str, instance: str
|
endpoint_name: str, account_api_key: str, instance: str
|
||||||
@@ -30,13 +48,13 @@ class Endpoint:
|
|||||||
Returns:
|
Returns:
|
||||||
Endpoint API key if successful, None otherwise
|
Endpoint API key if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
|
|
||||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
|
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
import tempfile
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_cert_file_path():
|
||||||
|
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||||
|
response = requests.get(cert_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
# Use a temporary file that is not deleted on close
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
return f.name
|
||||||
@@ -0,0 +1,210 @@
|
|||||||
|
# ComfyUI PyWorker
|
||||||
|
|
||||||
|
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
|
||||||
|
|
||||||
|
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [ComfyUI API Wrapper](https://github.com/ai-dock/comfyui-api-wrapper).
|
||||||
|
|
||||||
|
A docker image is provided but you may use any if the above requirements are met.
|
||||||
|
|
||||||
|
## Benchmarking
|
||||||
|
|
||||||
|
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
|
||||||
|
|
||||||
|
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
|
||||||
|
|
||||||
|
| Environment Variable | Default Value | Description |
|
||||||
|
| -------------------- | ------------- | ----------- |
|
||||||
|
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
|
||||||
|
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
|
||||||
|
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
|
||||||
|
|
||||||
|
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||||
|
|
||||||
|
### Calibrating Benchmark Duration
|
||||||
|
|
||||||
|
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||||
|
|
||||||
|
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Measure it/sec on your reference machine
|
||||||
|
# RTX 4090 typically achieves ~43 it/sec with SD1.5
|
||||||
|
|
||||||
|
# 2. Calculate required steps
|
||||||
|
# 90 seconds × 43 it/sec = 3870 steps
|
||||||
|
|
||||||
|
# 3. Configure benchmark
|
||||||
|
export BENCHMARK_TEST_STEPS=3870
|
||||||
|
|
||||||
|
# 4. Machines completing significantly slower than 90s indicate hardware issues
|
||||||
|
```
|
||||||
|
|
||||||
|
**Performance expectations:**
|
||||||
|
- Benchmark duration should remain consistent across identical GPU models
|
||||||
|
- Significant variation (>20%) may indicate thermal, power, or configuration issues
|
||||||
|
|
||||||
|
## Endpoint
|
||||||
|
|
||||||
|
The worker provides a single endpoint:
|
||||||
|
|
||||||
|
- `/generate/sync`: Processes ComfyUI workflows using either predefined modifiers or custom workflow JSON
|
||||||
|
|
||||||
|
## Request Format
|
||||||
|
|
||||||
|
The worker accepts requests in the following format. Choose either modifier mode OR custom workflow mode:
|
||||||
|
|
||||||
|
**Modifier Mode:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"request_id": "uuid-string", // optional - UUID generated if not provided
|
||||||
|
"modifier": "RawWorkflow",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": "a beautiful landscape",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"steps": 20,
|
||||||
|
"seed": 123456789
|
||||||
|
},
|
||||||
|
"s3": { ... }, // optional
|
||||||
|
"webhook": { ... } // optional
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Custom Workflow Mode:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"request_id": "uuid-string", // optional - UUID generated if not provided
|
||||||
|
"workflow_json": {
|
||||||
|
// Complete ComfyUI workflow JSON
|
||||||
|
},
|
||||||
|
"s3": { ... }, // optional
|
||||||
|
"webhook": { ... } // optional
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Request Fields
|
||||||
|
|
||||||
|
### Required Fields
|
||||||
|
|
||||||
|
- **`input`**: Contains the main workflow data
|
||||||
|
- **`input.request_id`**: Unique identifier for the request
|
||||||
|
|
||||||
|
### Workflow Mode (Choose One)
|
||||||
|
|
||||||
|
You must provide either `modifier` OR `workflow_json`, but not both:
|
||||||
|
|
||||||
|
#### Option 1: Modifier Mode
|
||||||
|
- **`input.modifier`**: Name of the predefined workflow modifier (e.g., "Text2Image")
|
||||||
|
- **`input.modifications`**: Parameters to pass to the modifier
|
||||||
|
|
||||||
|
#### Option 2: Custom Workflow Mode
|
||||||
|
- **`input.workflow_json`**: Complete ComfyUI workflow JSON
|
||||||
|
|
||||||
|
### Optional Fields
|
||||||
|
|
||||||
|
- **`input.s3`**: S3 configuration for file storage
|
||||||
|
- **`input.webhook`**: Webhook configuration for notifications
|
||||||
|
|
||||||
|
These configurations can be provided in the request JSON or via environment variables. Request-level configuration takes precedence over environment variables.
|
||||||
|
|
||||||
|
#### S3 Configuration
|
||||||
|
|
||||||
|
**Via Request JSON:**
|
||||||
|
```json
|
||||||
|
"s3": {
|
||||||
|
"access_key_id": "your-s3-access-key",
|
||||||
|
"secret_access_key": "your-s3-secret-access-key",
|
||||||
|
"endpoint_url": "https://my-endpoint.backblaze.com",
|
||||||
|
"bucket_name": "your-bucket",
|
||||||
|
"region": "us-east-1"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Via Environment Variables:**
|
||||||
|
```bash
|
||||||
|
S3_ACCESS_KEY_ID=your-key
|
||||||
|
S3_SECRET_ACCESS_KEY=your-secret
|
||||||
|
S3_BUCKET_NAME=your-bucket
|
||||||
|
S3_ENDPOINT_URL=https://s3.amazonaws.com
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Webhook Configuration
|
||||||
|
|
||||||
|
**Via Request JSON:**
|
||||||
|
```json
|
||||||
|
"webhook": {
|
||||||
|
"url": "your-webhook-url",
|
||||||
|
"extra_params": {
|
||||||
|
"custom_field": "value"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Via Environment Variables:**
|
||||||
|
```bash
|
||||||
|
WEBHOOK_URL=https://your-webhook.com # Default webhook URL
|
||||||
|
WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
### Basic Text-to-Image (Modifier Mode)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"modifier": "Text2Image",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": "a cat sitting on a windowsill",
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"steps": 20,
|
||||||
|
"seed": 42
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Workflow Mode
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"request_id": "67890", // optional - using custom ID for tracking
|
||||||
|
"workflow_json": {
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"seed": 42,
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 8,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": ["4", 0],
|
||||||
|
"positive": ["6", 0],
|
||||||
|
"negative": ["7", 0],
|
||||||
|
"latent_image": ["5", 0]
|
||||||
|
},
|
||||||
|
"class_type": "KSampler"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Client Libraries
|
||||||
|
|
||||||
|
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import random
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from lib.test_utils import print_truncate_res
|
||||||
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
from .data_types import count_workload
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def call_text2image_workflow(
|
||||||
|
endpoint_group_name: str, api_key: str, server_url: str
|
||||||
|
) -> None:
|
||||||
|
"""Simple Text2Image using the new modifier-based approach"""
|
||||||
|
|
||||||
|
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
|
||||||
|
"""Helper function for making requests with consistent error handling"""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
timeout=timeout,
|
||||||
|
verify=verify
|
||||||
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
except requests.exceptions.HTTPError as http_err:
|
||||||
|
log.error(f"HTTP error occurred during {context}: {http_err}")
|
||||||
|
log.error(f"Status Code: {response.status_code}")
|
||||||
|
log.error("Response content:", response.text)
|
||||||
|
return None
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
log.error(f"Timeout occurred during {context}: {url}")
|
||||||
|
return None
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
log.error(f"Connection error occurred during {context}: {url}")
|
||||||
|
return None
|
||||||
|
except json.JSONDecodeError as json_err:
|
||||||
|
log.error(f"Failed to decode JSON response during {context}: {json_err}")
|
||||||
|
if 'response' in locals():
|
||||||
|
print("Response content:", response.text)
|
||||||
|
return None
|
||||||
|
except Exception as err:
|
||||||
|
log.error(f"An unexpected error occurred during {context}: {err}")
|
||||||
|
if 'response' in locals():
|
||||||
|
log.error("Response content (if available):", response.text)
|
||||||
|
return None
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/generate/sync"
|
||||||
|
|
||||||
|
# This worker has concurrency = 1. All workloads have cost value 1.0
|
||||||
|
COST = count_workload()
|
||||||
|
|
||||||
|
# Route to get worker URL
|
||||||
|
route_payload = {
|
||||||
|
"endpoint": endpoint_group_name,
|
||||||
|
"api_key": api_key,
|
||||||
|
"cost": COST,
|
||||||
|
}
|
||||||
|
|
||||||
|
# First request - get routing information
|
||||||
|
route_response = make_request(
|
||||||
|
url=urljoin(server_url, "/route/"),
|
||||||
|
payload=route_payload,
|
||||||
|
timeout=4,
|
||||||
|
context="route request"
|
||||||
|
)
|
||||||
|
|
||||||
|
if route_response is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "url" not in route_response or not route_response["url"]:
|
||||||
|
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "status" in route_response:
|
||||||
|
print(f"Autoscaler status: {route_response['status']}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract data from route response
|
||||||
|
url = route_response["url"]
|
||||||
|
auth_data = dict(
|
||||||
|
signature=route_response["signature"],
|
||||||
|
cost=route_response["cost"],
|
||||||
|
endpoint=route_response["endpoint"],
|
||||||
|
reqnum=route_response["reqnum"],
|
||||||
|
url=route_response["url"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the payload for the worker request
|
||||||
|
worker_payload = {
|
||||||
|
"input": {
|
||||||
|
"request_id": str(uuid.uuid4()),
|
||||||
|
"modifier": "Text2Image",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": "a beautiful landscape with mountains and lakes",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"steps": 20,
|
||||||
|
"seed": random.randint(0, 2**32 - 1)
|
||||||
|
},
|
||||||
|
"workflow_json": {} # Empty since using modifier approach
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req_data = dict(payload=worker_payload, auth_data=auth_data)
|
||||||
|
worker_url = urljoin(url, WORKER_ENDPOINT)
|
||||||
|
print(f"url: {worker_url}")
|
||||||
|
|
||||||
|
# Second request - call the worker endpoint
|
||||||
|
worker_response = make_request(
|
||||||
|
url=worker_url,
|
||||||
|
payload=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
context="worker request"
|
||||||
|
)
|
||||||
|
|
||||||
|
return worker_response
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
|
args = test_args.parse_args()
|
||||||
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
instance=args.instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
if endpoint_api_key:
|
||||||
|
result = call_text2image_workflow(
|
||||||
|
api_key=endpoint_api_key,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
server_url=args.server_url,
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
log.error("Text2Image workflow failed")
|
||||||
|
else:
|
||||||
|
print(result)
|
||||||
|
else:
|
||||||
|
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
import dataclasses
|
||||||
|
from typing import Dict, Any
|
||||||
|
from functools import cache
|
||||||
|
from math import ceil
|
||||||
|
|
||||||
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
|
|
||||||
|
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||||
|
test_prompts = f.readlines()
|
||||||
|
|
||||||
|
def count_workload() -> float:
|
||||||
|
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||||
|
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
|
||||||
|
return 100.0
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ComfyWorkflowData(ApiPayload):
|
||||||
|
input: dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls):
|
||||||
|
"""
|
||||||
|
Use the variables available to simulate workflows of the required running time
|
||||||
|
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||||
|
"""
|
||||||
|
test_prompt = random.choice(test_prompts).rstrip()
|
||||||
|
return cls(
|
||||||
|
input={
|
||||||
|
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||||
|
"modifier": "Text2Image",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": test_prompt,
|
||||||
|
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
|
||||||
|
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
|
||||||
|
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
|
||||||
|
"seed": random.randint(0, sys.maxsize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
# input is already a dict, just return it wrapped in the expected structure
|
||||||
|
return {"input": self.input}
|
||||||
|
|
||||||
|
def count_workload(self) -> float:
|
||||||
|
return count_workload()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
|
||||||
|
# Extract required fields
|
||||||
|
if "input" not in json_msg:
|
||||||
|
raise JsonDataException({"input": "missing parameter"})
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
input=json_msg["input"]
|
||||||
|
)
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
|
||||||
|
stardew valley, fine details
|
||||||
|
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
|
||||||
|
realistic futuristic city-downtown with short buildings, sunset
|
||||||
|
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
|
||||||
|
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
|
||||||
|
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
|
||||||
|
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
|
||||||
|
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
|
||||||
|
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
|
||||||
|
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
|
||||||
|
Pope Francis wearing biker (leather jacket), a masterpiece
|
||||||
|
Luke Skywalker ordering a burger and fries from the Death Star canteen.
|
||||||
|
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
|
||||||
|
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
|
||||||
|
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
|
||||||
|
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
|
||||||
|
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
|
||||||
|
london luxurious interior living-room, light walls
|
||||||
|
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
|
||||||
|
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
|
||||||
|
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
|
||||||
|
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
|
||||||
|
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
|
||||||
|
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
|
||||||
|
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
|
||||||
|
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
|
||||||
|
the street of amedieval fantasy town, at dawn, dark, highly detailed
|
||||||
|
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
|
||||||
|
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
|
||||||
|
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import dataclasses
|
||||||
|
import base64
|
||||||
|
from typing import Optional, Union, Type
|
||||||
|
|
||||||
|
from aiohttp import web, ClientResponse
|
||||||
|
|
||||||
|
from lib.backend import Backend, LogAction
|
||||||
|
from lib.data_types import EndpointHandler
|
||||||
|
from lib.server import start_server
|
||||||
|
from .data_types import ComfyWorkflowData
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
|
||||||
|
|
||||||
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
|
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||||
|
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
# Check if the response is actually streaming based on response headers/content-type
|
||||||
|
is_streaming_response = (
|
||||||
|
model_response.content_type == "text/event-stream"
|
||||||
|
or model_response.content_type == "application/x-ndjson"
|
||||||
|
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||||
|
or "stream" in model_response.content_type.lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_streaming_response:
|
||||||
|
log.debug("Detected streaming response...")
|
||||||
|
res = web.StreamResponse()
|
||||||
|
res.content_type = model_response.content_type
|
||||||
|
await res.prepare(client_request)
|
||||||
|
async for chunk in model_response.content:
|
||||||
|
await res.write(chunk)
|
||||||
|
await res.write_eof()
|
||||||
|
log.debug("Done streaming response")
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
log.debug("Detected non-streaming response...")
|
||||||
|
content = await model_response.read()
|
||||||
|
return web.Response(
|
||||||
|
body=content,
|
||||||
|
status=model_response.status,
|
||||||
|
content_type=model_response.content_type
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/generate/sync"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
return f"{MODEL_SERVER_URL}/health"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[ComfyWorkflowData]:
|
||||||
|
return ComfyWorkflowData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> ComfyWorkflowData:
|
||||||
|
return ComfyWorkflowData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
return await generate_client_response(client_request, model_response)
|
||||||
|
|
||||||
|
|
||||||
|
backend = Backend(
|
||||||
|
model_server_url=MODEL_SERVER_URL,
|
||||||
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
|
allow_parallel_requests=False,
|
||||||
|
benchmark_handler=ComfyWorkflowHandler(
|
||||||
|
benchmark_runs=3, benchmark_words=100
|
||||||
|
),
|
||||||
|
log_actions=[
|
||||||
|
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||||
|
(LogAction.Info, "Downloading:"),
|
||||||
|
*[
|
||||||
|
(LogAction.ModelError, error_msg)
|
||||||
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ping(_):
|
||||||
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
|
||||||
|
web.get("/ping", handle_ping),
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_server(backend, routes)
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
from lib.test_utils import test_load_cmd, test_args
|
||||||
|
from .data_types import ComfyWorkflowData
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/generate/sync"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||||
@@ -5,6 +5,7 @@ import requests
|
|||||||
|
|
||||||
from lib.test_utils import print_truncate_res
|
from lib.test_utils import print_truncate_res
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
"""
|
"""
|
||||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||||
@@ -51,6 +52,7 @@ def call_default_workflow(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from lib.server import start_server
|
|||||||
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
||||||
|
|
||||||
|
|
||||||
MODEL_SERVER_URL = "http://0.0.0.0:38188"
|
MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
|
||||||
|
|
||||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from urllib.parse import urljoin
|
|||||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -90,9 +91,13 @@ class APIClient:
|
|||||||
|
|
||||||
# Make the request using the specified method
|
# Make the request using the specified method
|
||||||
if method.upper() == "POST":
|
if method.upper() == "POST":
|
||||||
response = requests.post(url, json=req_data, stream=stream)
|
response = requests.post(
|
||||||
|
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
elif method.upper() == "GET":
|
elif method.upper() == "GET":
|
||||||
response = requests.get(url, params=req_data, stream=stream)
|
response = requests.get(
|
||||||
|
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
|
|
||||||
@@ -562,7 +567,7 @@ def main():
|
|||||||
client = APIClient(
|
client = APIClient(
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
api_key=args.api_key,
|
api_key=args.api_key,
|
||||||
server_url=args.server_url,
|
server_url=Endpoint.get_autoscaler_server_url(args.instance),
|
||||||
endpoint_api_key=endpoint_api_key,
|
endpoint_api_key=endpoint_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
import requests
|
import requests
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
|||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
print(f"url: {url}")
|
print(f"url: {url}")
|
||||||
response = requests.post(url, json=req_data)
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
res = response.json()
|
res = response.json()
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
Reference in New Issue
Block a user