adding s3 as an option
This commit is contained in:
@@ -12,16 +12,45 @@ from vastai import Serverless
|
||||
|
||||
# ---------------------- Config ----------------------
|
||||
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
|
||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||
ENDPOINT_NAME = "Comfy-Prod"
|
||||
DEFAULT_WIDTH = 512
|
||||
DEFAULT_HEIGHT = 512
|
||||
DEFAULT_STEPS = 20
|
||||
COST = 100 # Fixed cost for ComfyUI requests
|
||||
|
||||
# Optional S3 Configuration (from environment variables)
|
||||
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
|
||||
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
||||
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
|
||||
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_s3_client():
|
||||
"""Create and return an S3 client configured for the S3-compatible endpoint"""
|
||||
try:
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
except ImportError:
|
||||
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
|
||||
return None
|
||||
|
||||
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
|
||||
log.error("S3 environment variables not fully configured. Required:")
|
||||
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
|
||||
return None
|
||||
|
||||
return boto3.client(
|
||||
"s3",
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
config=Config(signature_version="s3v4"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------- API Functions ----------------------
|
||||
async def call_generate(
|
||||
client: Serverless,
|
||||
@@ -70,9 +99,14 @@ async def call_generate_workflow(
|
||||
|
||||
# ---------------------- Demo Class ----------------------
|
||||
class APIDemo:
|
||||
def __init__(self, client: Serverless, endpoint_name: str):
|
||||
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
|
||||
self.client = client
|
||||
self.endpoint_name = endpoint_name
|
||||
self.upload_s3 = upload_s3
|
||||
self.s3_client = get_s3_client() if upload_s3 else None
|
||||
|
||||
if upload_s3 and not self.s3_client:
|
||||
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
|
||||
|
||||
def extract_filename(self, response: dict) -> str | None:
|
||||
"""Extract the generated image filename from ComfyUI response"""
|
||||
@@ -85,10 +119,29 @@ class APIDemo:
|
||||
return None
|
||||
|
||||
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||
"""Fetch and save image locally from the worker"""
|
||||
"""Fetch and save image locally from the worker, optionally upload to S3"""
|
||||
os.makedirs("generated_images", exist_ok=True)
|
||||
return await self._fetch_image(worker_url, filename, local_name)
|
||||
|
||||
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
|
||||
"""Upload a local file to S3 and return the S3 URL"""
|
||||
if not self.s3_client:
|
||||
return None
|
||||
|
||||
try:
|
||||
self.s3_client.upload_file(
|
||||
local_path,
|
||||
S3_BUCKET_NAME,
|
||||
s3_key,
|
||||
ExtraArgs={"ContentType": "image/png"}
|
||||
)
|
||||
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
|
||||
print(f" ☁️ Uploaded to S3: {s3_key}")
|
||||
return s3_url
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload to S3: {e}")
|
||||
return None
|
||||
|
||||
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||
"""Fetch image from worker's /view endpoint and save locally"""
|
||||
if not worker_url:
|
||||
@@ -102,9 +155,16 @@ class APIDemo:
|
||||
async with session.get(url, params=params, ssl=False) as resp:
|
||||
if resp.status == 200:
|
||||
path = f"generated_images/{local_name}"
|
||||
image_data = await resp.read()
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
f.write(image_data)
|
||||
print(f" 💾 Saved: {path}")
|
||||
|
||||
# Upload to S3 if enabled
|
||||
if self.upload_s3 and self.s3_client:
|
||||
s3_key = f"comfyui/{local_name}"
|
||||
self._upload_to_s3(path, s3_key)
|
||||
|
||||
return path
|
||||
return None
|
||||
except Exception:
|
||||
@@ -207,6 +267,8 @@ def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
|
||||
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
|
||||
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
|
||||
p.add_argument("--s3", action="store_true",
|
||||
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
|
||||
return p
|
||||
|
||||
|
||||
@@ -215,10 +277,12 @@ async def main_async():
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
if args.s3:
|
||||
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.endpoint)
|
||||
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
|
||||
|
||||
if args.workflow:
|
||||
await demo.demo_workflow(workflow_file=args.workflow)
|
||||
|
||||
Reference in New Issue
Block a user