Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 7437028cb2 | |||
| 02c8307af7 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 9f5a432513 | |||
| e09f1fa953 | |||
| ba6f1c2e4b | |||
| 944f83fc03 | |||
| 298590fb88 | |||
| 814c3acd4c | |||
| 22bca74087 | |||
| f56bbc0ebe |
+33
-92
@@ -5,7 +5,7 @@ import base64
|
|||||||
import subprocess
|
import subprocess
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
@@ -30,7 +30,7 @@ from lib.data_types import (
|
|||||||
BenchmarkResult
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.1.0"
|
VERSION = "0.2.0"
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
@@ -47,7 +47,7 @@ class Backend:
|
|||||||
This class is responsible for:
|
This class is responsible for:
|
||||||
1. Tailing logs and updating load time metrics
|
1. Tailing logs and updating load time metrics
|
||||||
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
||||||
sending the request. It also updates metrics as it makes those requests.
|
sending the request. It also updates metrics as it makes those requests.
|
||||||
3. Running a benchmark from an EndpointHandler
|
3. Running a benchmark from an EndpointHandler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -66,19 +66,21 @@ class Backend:
|
|||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
|
report_addr: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||||
|
)
|
||||||
|
mtoken: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
self.metrics._set_version(self.version)
|
self.metrics._set_version(self.version)
|
||||||
|
self.metrics._set_mtoken(self.mtoken)
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
|
|
||||||
# NEW: FIFO queue + worker count
|
|
||||||
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
|
|
||||||
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
|
|
||||||
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||||
if self._pubkey is None:
|
if self._pubkey is None:
|
||||||
@@ -96,22 +98,6 @@ class Backend:
|
|||||||
timeout = ClientTimeout(total=None)
|
timeout = ClientTimeout(total=None)
|
||||||
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||||
|
|
||||||
async def _worker(self):
|
|
||||||
while True:
|
|
||||||
handler, request, fut = await self.request_queue.get()
|
|
||||||
try:
|
|
||||||
# Skip if already cancelled while waiting in the queue
|
|
||||||
if fut.cancelled():
|
|
||||||
continue
|
|
||||||
res = await self.__process_enqueued_request(handler, request)
|
|
||||||
if not fut.cancelled():
|
|
||||||
fut.set_result(res)
|
|
||||||
except Exception as e:
|
|
||||||
if not fut.cancelled():
|
|
||||||
fut.set_exception(e)
|
|
||||||
finally:
|
|
||||||
self.request_queue.task_done()
|
|
||||||
|
|
||||||
def create_handler(
|
def create_handler(
|
||||||
self,
|
self,
|
||||||
handler: EndpointHandler[ApiPayload_T],
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
@@ -125,59 +111,26 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
report_addr = self.report_addr.rstrip("/")
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||||
log.debug("public key:")
|
try:
|
||||||
log.debug(result)
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
key = None
|
log.debug("public key:")
|
||||||
for _ in range(5):
|
log.debug(result)
|
||||||
try:
|
key = RSA.import_key(result)
|
||||||
key = RSA.import_key(result)
|
if key is not None:
|
||||||
break
|
return key
|
||||||
except ValueError as e:
|
except (ValueError , subprocess.CalledProcessError) as e:
|
||||||
log.debug(f"Error downloading key: {e}")
|
log.debug(f"Error downloading key: {e}")
|
||||||
time.sleep(15)
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
if key is None:
|
|
||||||
self._total_pubkey_fetch_errors += 1
|
|
||||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
|
||||||
return key
|
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
handler: EndpointHandler[ApiPayload_T],
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
request: web.Request,
|
request: web.Request,
|
||||||
) -> Union[web.Response, web.StreamResponse]:
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
"""use this function to enqueue requests for FIFO processing"""
|
"""use this function to forward requests to the model endpoint"""
|
||||||
loop = get_running_loop()
|
|
||||||
fut: asyncio.Future = loop.create_future()
|
|
||||||
|
|
||||||
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
|
|
||||||
cancel_watch = create_task(request.wait_for_disconnection())
|
|
||||||
def _cancel_if_disconnected(_):
|
|
||||||
if not fut.done():
|
|
||||||
fut.cancel()
|
|
||||||
cancel_watch.add_done_callback(_cancel_if_disconnected)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.request_queue.put((handler, request, fut))
|
|
||||||
return await fut
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Propagate cancellation to ensure aiohttp doesn't expect a response body
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# Best-effort cleanup of the watcher
|
|
||||||
cancel_watch.cancel()
|
|
||||||
|
|
||||||
async def __process_enqueued_request(
|
|
||||||
self,
|
|
||||||
handler: EndpointHandler[ApiPayload_T],
|
|
||||||
request: web.Request,
|
|
||||||
) -> Union[web.Response, web.StreamResponse]:
|
|
||||||
"""
|
|
||||||
This contains the original __handle_request logic and is invoked by workers,
|
|
||||||
ensuring FIFO execution via asyncio.Queue.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
auth_data, payload = handler.get_data_from_request(data)
|
auth_data, payload = handler.get_data_from_request(data)
|
||||||
@@ -185,11 +138,8 @@ class Backend:
|
|||||||
return web.json_response(data=e.message, status=422)
|
return web.json_response(data=e.message, status=422)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
|
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
request_metrics: RequestMetrics = RequestMetrics(
|
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||||
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
await request.wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
@@ -230,8 +180,6 @@ class Backend:
|
|||||||
acquired = False
|
acquired = False
|
||||||
try:
|
try:
|
||||||
self.metrics._request_start(request_metrics)
|
self.metrics._request_start(request_metrics)
|
||||||
|
|
||||||
# Preserve existing semaphore behavior for serializing requests when requested
|
|
||||||
if self.allow_parallel_requests is False:
|
if self.allow_parallel_requests is False:
|
||||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||||
await self.sem.acquire()
|
await self.sem.acquire()
|
||||||
@@ -241,7 +189,6 @@ class Backend:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||||
|
|
||||||
done, pending = await wait(
|
done, pending = await wait(
|
||||||
[
|
[
|
||||||
create_task(make_request()),
|
create_task(make_request()),
|
||||||
@@ -309,14 +256,8 @@ class Backend:
|
|||||||
self.backend_errored(str(e))
|
self.backend_errored(str(e))
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
# Start the FIFO workers alongside existing loops
|
|
||||||
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
|
|
||||||
await gather(
|
await gather(
|
||||||
self.__read_logs(),
|
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||||
self.metrics._send_metrics_loop(),
|
|
||||||
self.__healthcheck(),
|
|
||||||
self.metrics._send_delete_requests_loop(),
|
|
||||||
*worker_tasks,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
@@ -348,7 +289,7 @@ class Backend:
|
|||||||
message = {
|
message = {
|
||||||
key: value
|
key: value
|
||||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
if key != "signature"
|
if key != "signature" and key != "__request_id"
|
||||||
}
|
}
|
||||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -358,7 +299,7 @@ class Backend:
|
|||||||
elif message in self.msg_history:
|
elif message in self.msg_history:
|
||||||
log.debug(f"message: {message} already in message history")
|
log.debug(f"message: {message} already in message history")
|
||||||
return False
|
return False
|
||||||
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
self.msg_history.append(message)
|
self.msg_history.append(message)
|
||||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
@@ -377,10 +318,10 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
_ = await self.__call_api(
|
# _ = await self.__call_api(
|
||||||
handler=self.benchmark_handler, payload=payload
|
# handler=self.benchmark_handler, payload=payload
|
||||||
)
|
# )
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
@@ -455,7 +396,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
await sleep(5)
|
# await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
|
|||||||
+6
-4
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
|
|||||||
class AuthData:
|
class AuthData:
|
||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
signature: str
|
|
||||||
cost: str
|
cost: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
reqnum: int
|
reqnum: int
|
||||||
url: str
|
|
||||||
request_idx: int
|
request_idx: int
|
||||||
|
signature: str
|
||||||
|
url: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||||
@@ -190,11 +190,12 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, expected: float | None) -> None:
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
self.model_loading_time = None
|
if self.model_loading_time == expected:
|
||||||
|
self.model_loading_time = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -285,6 +286,7 @@ class AutoScalerData:
|
|||||||
"""Data that is reported to autoscaler"""
|
"""Data that is reported to autoscaler"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
|
mtoken: str
|
||||||
version: str
|
version: str
|
||||||
loadtime: float
|
loadtime: float
|
||||||
cur_load: float
|
cur_load: float
|
||||||
|
|||||||
+33
-9
@@ -28,6 +28,7 @@ def get_url() -> str:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Metrics:
|
class Metrics:
|
||||||
version: str = "0"
|
version: str = "0"
|
||||||
|
mtoken: str = ""
|
||||||
last_metric_update: float = 0.0
|
last_metric_update: float = 0.0
|
||||||
last_request_served: float = 0.0
|
last_request_served: float = 0.0
|
||||||
update_pending: bool = False
|
update_pending: bool = False
|
||||||
@@ -142,12 +143,16 @@ class Metrics:
|
|||||||
def _set_version(self, version: str) -> None:
|
def _set_version(self, version: str) -> None:
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|
||||||
|
def _set_mtoken(self, mtoken: str) -> None:
|
||||||
|
self.mtoken = mtoken
|
||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
async def __send_delete_requests_and_reset(self):
|
async def __send_delete_requests_and_reset(self):
|
||||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||||
data = {
|
data = {
|
||||||
"worker_id": self.id,
|
"worker_id": self.id,
|
||||||
|
"mtoken": self.mtoken,
|
||||||
"request_idxs": idxs,
|
"request_idxs": idxs,
|
||||||
"success": success_flag,
|
"success": success_flag,
|
||||||
}
|
}
|
||||||
@@ -180,6 +185,10 @@ class Metrics:
|
|||||||
return # nothing to do
|
return # nothing to do
|
||||||
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
|
# TODO: Add a Redis subscriber queue for delete_requests
|
||||||
|
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||||
|
# Patch: ignore the Redis API report_addr
|
||||||
|
continue
|
||||||
sent_success = True
|
sent_success = True
|
||||||
sent_failed = True
|
sent_failed = True
|
||||||
|
|
||||||
@@ -200,11 +209,14 @@ class Metrics:
|
|||||||
|
|
||||||
async def __send_metrics_and_reset(self):
|
async def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
|
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
|
mtoken=self.mtoken,
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
loadtime=(loadtime_snapshot or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
cur_load=self.model_metrics.cur_load,
|
cur_load=self.model_metrics.cur_load,
|
||||||
rej_load=self.model_metrics.workload_rejected,
|
rej_load=self.model_metrics.workload_rejected,
|
||||||
@@ -222,17 +234,25 @@ class Metrics:
|
|||||||
|
|
||||||
async def send_data(report_addr: str) -> bool:
|
async def send_data(report_addr: str) -> bool:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
log_data = asdict(data)
|
||||||
|
def obfuscate(secret: str) -> str:
|
||||||
|
if secret is None:
|
||||||
|
return ""
|
||||||
|
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||||
|
|
||||||
|
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
f"sending data to autoscaler",
|
f"sending data to autoscaler",
|
||||||
f"{json.dumps((asdict(data)), indent=2)}",
|
f"{json.dumps(log_data, indent=2)}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
session = await self.http()
|
session = await self.http()
|
||||||
@@ -252,11 +272,15 @@ class Metrics:
|
|||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
|
sent = False
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
success = await send_data(report_addr)
|
if await send_data(report_addr):
|
||||||
if success is True:
|
sent = True
|
||||||
break
|
break
|
||||||
self.update_pending = False
|
|
||||||
self.model_metrics.reset()
|
if sent:
|
||||||
self.system_metrics.reset()
|
# clear the one-shot loadtime only if we actually sent *this* value
|
||||||
self.last_metric_update = time.time()
|
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||||
|
self.update_pending = False
|
||||||
|
self.model_metrics.reset()
|
||||||
|
self.last_metric_update = time.time()
|
||||||
|
|||||||
+1
-1
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
|
|||||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
|
||||||
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
|
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||||
USE_SSL="${USE_SSL:-true}"
|
USE_SSL="${USE_SSL:-true}"
|
||||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||||
mkdir -p "$WORKSPACE_DIR"
|
mkdir -p "$WORKSPACE_DIR"
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ def call_text2image_workflow(
|
|||||||
endpoint=route_response["endpoint"],
|
endpoint=route_response["endpoint"],
|
||||||
reqnum=route_response["reqnum"],
|
reqnum=route_response["reqnum"],
|
||||||
url=route_response["url"],
|
url=route_response["url"],
|
||||||
|
request_idx=route_response["request_idx"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the payload for the worker request
|
# Build the payload for the worker request
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
endpoint=message["endpoint"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
|
request_idx=message["request_idx"],
|
||||||
)
|
)
|
||||||
workflow = {
|
workflow = {
|
||||||
"3": {
|
"3": {
|
||||||
|
|||||||
Reference in New Issue
Block a user