Compare commits

..

1 Commit

Author SHA1 Message Date
Lucas Armand 7d3be849d9 Handle errors from model for comfyui-json 2025-10-08 12:00:45 -07:00
3 changed files with 54 additions and 66 deletions
+4 -5
View File
@@ -45,7 +45,6 @@ class Metrics:
self.model_metrics.workload_received += workload self.model_metrics.workload_received += workload
self.model_metrics.requests_recieved.add(reqnum) self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum) self.model_metrics.requests_working.add(reqnum)
self.update_pending = True
def _request_end(self, workload: float, reqnum: int) -> None: def _request_end(self, workload: float, reqnum: int) -> None:
""" """
@@ -79,10 +78,10 @@ class Metrics:
elapsed = time.time() - self.last_metric_update elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10: if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait") log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset() self.__send_metrics_and_reset(elapsed)
elif self.update_pending or elapsed > 10: elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset() self.__send_metrics_and_reset(elapsed)
def _model_loaded(self, max_throughput: float) -> None: def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = ( self.system_metrics.model_loading_time = (
@@ -97,13 +96,13 @@ class Metrics:
#######################################Private####################################### #######################################Private#######################################
def __send_metrics_and_reset(self): def __send_metrics_and_reset(self, elapsed):
def compute_autoscaler_data() -> AutoScalaerData: def compute_autoscaler_data() -> AutoScalaerData:
return AutoScalaerData( return AutoScalaerData(
id=self.id, id=self.id,
loadtime=(self.system_metrics.model_loading_time or 0.0), loadtime=(self.system_metrics.model_loading_time or 0.0),
cur_load=(self.model_metrics.workload_processing), cur_load=(self.model_metrics.workload_processing / elapsed),
max_perf=self.model_metrics.max_throughput, max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf, cur_perf=self.model_metrics.cur_perf,
error_msg=self.model_metrics.error_msg or "", error_msg=self.model_metrics.error_msg or "",
+18 -35
View File
@@ -3,7 +3,8 @@
set -e -o pipefail set -e -o pipefail
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}" WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
SERVER_DIR="$WORKSPACE_DIR/worker"
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
ENV_PATH="$WORKSPACE_DIR/worker-env" ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
@@ -21,23 +22,24 @@ function echo_var(){
echo "$1: ${!1}" echo "$1: ${!1}"
} }
# Updated validation - BACKEND no longer required, but MODEL_LOG still is [ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1
[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1 [ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1
[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1 [ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && echo "For comfyui backends, COMFY_MODEL must be set!" && exit 1
echo "start_server.sh - SDK Worker Version"
echo "start_server.sh"
date date
echo_var BACKEND
echo_var REPORT_ADDR echo_var REPORT_ADDR
echo_var WORKER_PORT echo_var WORKER_PORT
echo_var WORKSPACE_DIR echo_var WORKSPACE_DIR
echo_var SERVER_DIR
echo_var ENV_PATH echo_var ENV_PATH
echo_var DEBUG_LOG echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
echo_var MODEL_SERVER_URL
echo_var PYWORKER_REPO
echo_var PYWORKER_REF
# Populate /etc/environment with quoted values # Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then if ! grep -q "VAST" /etc/environment; then
@@ -56,32 +58,16 @@ then
source ~/.local/bin/env source ~/.local/bin/env
fi fi
if [[ ! -d $SERVER_DIR ]]; then # Fork testing
echo "Cloning worker repository..." [[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
git clone --depth=1 "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
fi
if [[ -n ${PYWORKER_REF:-} ]]; then if [[ -n ${PYWORKER_REF:-} ]]; then
echo "Checking out ref: $PYWORKER_REF" (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
(
cd "$SERVER_DIR"
git fetch --depth=1 origin "$PYWORKER_REF"
git checkout "$PYWORKER_REF"
)
fi fi
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10 uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate" source "$ENV_PATH/bin/activate"
# Install vast-sdk from server-side-sdk branch uv pip install -r "${SERVER_DIR}/requirements.txt"
echo "Installing vast-sdk from GitHub (server-side-sdk branch)..."
uv pip install "git+https://github.com/vast-ai/vast-sdk.git@server-side-sdk"
# Install requirements from worker repo if they exist
if [ -f "${SERVER_DIR}/requirements.txt" ]; then
echo "Installing additional dependencies from requirements.txt..."
uv pip install -r "${SERVER_DIR}/requirements.txt"
fi
touch ~/.no_auto_tmux touch ~/.no_auto_tmux
else else
@@ -91,12 +77,7 @@ else
echo "venv: $VIRTUAL_ENV" echo "venv: $VIRTUAL_ENV"
fi fi
# Check that worker.py exists [ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
if [ ! -f "$SERVER_DIR/worker.py" ]; then
echo "ERROR: worker.py not found in $SERVER_DIR"
echo "Please ensure your PYWORKER_REPO contains a worker.py file"
exit 1
fi
if [ "$USE_SSL" = true ]; then if [ "$USE_SSL" = true ]; then
@@ -134,6 +115,9 @@ EOF
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi fi
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR" cd "$SERVER_DIR"
@@ -144,6 +128,5 @@ echo "launching PyWorker server"
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only # from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
# Launch the SDK-based worker instead of the old backend system (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
(python3 worker.py |& tee -a "$PYWORKER_LOG") & echo "launching PyWorker server done"
echo "launching PyWorker server done"
+32 -26
View File
@@ -33,33 +33,39 @@ log = logging.getLogger(__file__)
async def generate_client_response( async def generate_client_response(
client_request: web.Request, model_response: ClientResponse client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
# Check if the response is actually streaming based on response headers/content-type match model_response.status:
is_streaming_response = ( case 200:
model_response.content_type == "text/event-stream" log.debug("SUCCESS")
or model_response.content_type == "application/x-ndjson" # Check if the response is actually streaming based on response headers/content-type
or model_response.headers.get("Transfer-Encoding") == "chunked" is_streaming_response = (
or "stream" in model_response.content_type.lower() model_response.content_type == "text/event-stream"
) or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response: if is_streaming_response:
log.debug("Detected streaming response...") log.debug("Detected streaming response...")
res = web.StreamResponse() res = web.StreamResponse()
res.content_type = model_response.content_type res.content_type = model_response.content_type
await res.prepare(client_request) await res.prepare(client_request)
async for chunk in model_response.content: async for chunk in model_response.content:
await res.write(chunk) await res.write(chunk)
await res.write_eof() await res.write_eof()
log.debug("Done streaming response") log.debug("Done streaming response")
return res return res
else: else:
log.debug("Detected non-streaming response...") log.debug("Detected non-streaming response...")
content = await model_response.read() content = await model_response.read()
return web.Response( return web.Response(
body=content, body=content,
status=model_response.status, status=model_response.status,
content_type=model_response.content_type content_type=model_response.content_type
) )
case code:
log.debug(f"Model responded with error {code}")
return web.Response(status=code)
@dataclasses.dataclass @dataclasses.dataclass
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]): class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):