import gc
import logging
import os
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from PIL import Image, ImageFilter
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers import StableDiffusionXLControlNetInpaintPipeline
from diffusers import StableDiffusionXLInpaintPipeline
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from transformers import DPTImageProcessor, DPTForDepthEstimation

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class InpaintingConfig:
    """Configuration for inpainting operations."""

    # ControlNet settings
    controlnet_conditioning_scale: float = 0.7
    conditioning_type: str = "canny"  # "canny" or "depth"

    # Canny edge detection parameters
    canny_low_threshold: int = 100
    canny_high_threshold: int = 200

    # Mask settings
    feather_radius: int = 8
    min_mask_coverage: float = 0.01
    max_mask_coverage: float = 0.95

    # Generation settings
    num_inference_steps: int = 25
    guidance_scale: float = 7.5
    strength: float = 1.0  # Inpainting strength (0.0-1.0), 1.0 = full repaint
    preview_steps: int = 15
    preview_guidance_scale: float = 8.0

    # Quality settings
    enable_auto_optimization: bool = True
    max_optimization_retries: int = 3
    min_quality_score: float = 70.0

    # Memory settings
    enable_vae_tiling: bool = True
    enable_attention_slicing: bool = True
    max_resolution: int = 1024
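

# For reference, a caller can override any subset of these defaults before
# constructing the module; the values below are purely illustrative:
#
#   config = InpaintingConfig(num_inference_steps=30, feather_radius=12)
#   module = InpaintingModule(device="cuda", config=config)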


@dataclass
class InpaintingResult:
    """Result container for inpainting operations."""

    success: bool
    result_image: Optional[Image.Image] = None
    preview_image: Optional[Image.Image] = None
    control_image: Optional[Image.Image] = None
    blended_image: Optional[Image.Image] = None
    quality_score: float = 0.0
    quality_details: Dict[str, Any] = field(default_factory=dict)
    generation_time: float = 0.0
    retries: int = 0
    error_message: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)


class InpaintingModule:
    """
    ControlNet-based Inpainting Module for SceneWeaver.

    Implements StableDiffusionXLControlNetInpaintPipeline with support for
    Canny edge and depth map conditioning. Features two-stage generation
    (preview + full quality) and automatic quality optimization.

    Attributes:
        device: Computation device (cuda/mps/cpu)
        config: InpaintingConfig instance
        is_initialized: Whether pipeline is loaded

    Example:
        >>> module = InpaintingModule(device="cuda")
        >>> module.load_inpainting_pipeline(progress_callback=my_callback)
        >>> result = module.execute_inpainting(
        ...     image=my_image,
        ...     mask=my_mask,
        ...     prompt="a beautiful garden"
        ... )
    """

    # Model identifiers
    CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
    CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
    DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
    DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(
        self,
        device: str = "auto",
        config: Optional[InpaintingConfig] = None
    ):
        """
        Initialize the InpaintingModule.

        Parameters
        ----------
        device : str, optional
            Computation device. "auto" for automatic detection.
        config : InpaintingConfig, optional
            Configuration object. Uses defaults if not provided.
        """
        self.device = self._setup_device(device)
        self.config = config or InpaintingConfig()

        # Pipeline instances (lazy loaded)
        self._inpaint_pipeline = None
        self._controlnet_canny = None
        self._controlnet_depth = None
        self._depth_estimator = None
        self._depth_processor = None

        # State tracking
        self.is_initialized = False
        self._current_conditioning_type = None
        self._last_seed = None
        self._cached_latents = None
        self._use_controlnet = True  # Track if ControlNet is available

        # Reference to model manager (set by SceneWeaverCore)
        self._model_manager = None

        logger.info(f"InpaintingModule initialized on {self.device}")

    def _setup_device(self, device: str) -> str:
        """
        Setup computation device.

        Parameters
        ----------
        device : str
            Device specification or "auto"

        Returns
        -------
        str
            Resolved device name
        """
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        return device

    def set_model_manager(self, manager: Any) -> None:
        """
        Set reference to ModelManager for coordinated model lifecycle.

        Parameters
        ----------
        manager : ModelManager
            The global model manager instance
        """
        self._model_manager = manager
        logger.info("ModelManager reference set for InpaintingModule")

    def _memory_cleanup(self, aggressive: bool = False) -> None:
        """
        Perform memory cleanup.

        Parameters
        ----------
        aggressive : bool
            If True, perform multiple GC rounds and sync CUDA
        """
        rounds = 5 if aggressive else 2
        for _ in range(rounds):
            gc.collect()
        # On Hugging Face Spaces, avoid CUDA operations in the main process;
        # CUDA operations must only happen within @spaces.GPU decorated functions.
        is_spaces = os.getenv('SPACE_ID') is not None
        if not is_spaces and torch.cuda.is_available():
            torch.cuda.empty_cache()
            if aggressive:
                torch.cuda.ipc_collect()
                torch.cuda.synchronize()
        logger.debug(f"Memory cleanup completed (aggressive={aggressive}, spaces={is_spaces})")

    def _check_memory_status(self) -> Dict[str, float]:
        """
        Check current GPU memory status.

        Returns
        -------
        dict
            Memory statistics including allocated, total, and usage ratio
        """
        # On Spaces, skip CUDA checks in the main process
        is_spaces = os.getenv('SPACE_ID') is not None
        if is_spaces or not torch.cuda.is_available():
            return {"available": True, "usage_ratio": 0.0}
        allocated = torch.cuda.memory_allocated() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        usage_ratio = allocated / total
        return {
            "allocated_gb": round(allocated, 2),
            "total_gb": round(total, 2),
            "free_gb": round(total - allocated, 2),
            "usage_ratio": round(usage_ratio, 3),
            "available": usage_ratio < 0.9
        }

    def load_inpainting_pipeline(
        self,
        conditioning_type: str = "canny",
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> Tuple[bool, str]:
        """
        Load the ControlNet inpainting pipeline.

        Implements mutual exclusion with the background generation pipeline:
        only one pipeline can be loaded at a time.

        Parameters
        ----------
        conditioning_type : str
            Type of ControlNet conditioning: "canny" or "depth"
        progress_callback : callable, optional
            Function(message, percentage) for progress updates

        Returns
        -------
        tuple
            (success: bool, error_message: str)
        """
        if self.is_initialized and self._current_conditioning_type == conditioning_type:
            logger.info(f"Inpainting pipeline already loaded with {conditioning_type}")
            return True, ""

        logger.info(f"Loading inpainting pipeline with {conditioning_type} conditioning...")
        try:
            self._memory_cleanup(aggressive=True)
            if progress_callback:
                progress_callback("Preparing to load inpainting models...", 5)

            # Unload existing pipeline if different conditioning type
            if self._inpaint_pipeline is not None:
                self._unload_pipeline()

            # Use ControlNet inpainting by default
            use_controlnet_inpaint = True
            logger.info("Using StableDiffusionXLControlNetInpaintPipeline")

            if progress_callback:
                progress_callback("Loading ControlNet model...", 20)

            # Load the appropriate ControlNet
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            controlnet = None
            if use_controlnet_inpaint:
                if conditioning_type == "canny":
                    controlnet = ControlNetModel.from_pretrained(
                        self.CONTROLNET_CANNY_MODEL,
                        torch_dtype=dtype,
                        use_safetensors=True
                    )
                    self._controlnet_canny = controlnet
                    logger.info("Loaded ControlNet Canny model")
                elif conditioning_type == "depth":
                    controlnet = ControlNetModel.from_pretrained(
                        self.CONTROLNET_DEPTH_MODEL,
                        torch_dtype=dtype,
                        use_safetensors=True
                    )
                    self._controlnet_depth = controlnet
                    # Load depth estimator
                    if progress_callback:
                        progress_callback("Loading depth estimation model...", 35)
                    self._load_depth_estimator()
                    logger.info("Loaded ControlNet Depth model")
                else:
                    raise ValueError(f"Unknown conditioning type: {conditioning_type}")
            else:
                # Skip ControlNet loading for fallback mode
                logger.info("Skipping ControlNet loading (fallback mode)")

            if progress_callback:
                progress_callback("Loading SDXL Inpainting pipeline...", 50)

            # Load the inpainting pipeline
            if use_controlnet_inpaint and controlnet is not None:
                self._inpaint_pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
                    self.BASE_MODEL,
                    controlnet=controlnet,
                    torch_dtype=dtype,
                    use_safetensors=True,
                    variant="fp16" if dtype == torch.float16 else None
                )
            else:
                # Fallback: use the dedicated inpainting model without ControlNet
                self._inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
                    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                    torch_dtype=dtype,
                    use_safetensors=True,
                    variant="fp16" if dtype == torch.float16 else None
                )
                self._use_controlnet = False

            # Track ControlNet usage
            self._use_controlnet = use_controlnet_inpaint and controlnet is not None

            if progress_callback:
                progress_callback("Configuring scheduler...", 70)

            # Configure scheduler for faster generation
            self._inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self._inpaint_pipeline.scheduler.config
            )

            # Move to device
            self._inpaint_pipeline = self._inpaint_pipeline.to(self.device)

            if progress_callback:
                progress_callback("Applying optimizations...", 85)

            # Apply memory optimizations
            self._apply_pipeline_optimizations()

            # Set eval mode
            self._inpaint_pipeline.unet.eval()
            if hasattr(self._inpaint_pipeline, 'vae'):
                self._inpaint_pipeline.vae.eval()

            self.is_initialized = True
            self._current_conditioning_type = conditioning_type if self._use_controlnet else "none"

            if progress_callback:
                progress_callback("Inpainting pipeline ready!", 100)

            # Log memory status
            mem_status = self._check_memory_status()
            logger.info(f"Pipeline loaded. GPU memory: {mem_status.get('allocated_gb', 0):.1f}GB used")
            return True, ""

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load inpainting pipeline: {error_msg}")
            traceback.print_exc()
            self._unload_pipeline()
            return False, error_msg

    def _load_depth_estimator(self) -> None:
        """
        Load depth estimation model with fallback strategy.

        Tries Depth-Anything first, falls back to MiDaS if unavailable.
        """
        try:
            logger.info(f"Attempting to load depth model: {self.DEPTH_MODEL_PRIMARY}")
            self._depth_processor = AutoImageProcessor.from_pretrained(
                self.DEPTH_MODEL_PRIMARY
            )
            self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_PRIMARY,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Successfully loaded Depth-Anything model")
        except Exception as e:
            logger.warning(f"Primary depth model failed: {e}, trying fallback...")
            try:
                self._depth_processor = DPTImageProcessor.from_pretrained(
                    self.DEPTH_MODEL_FALLBACK
                )
                self._depth_estimator = DPTForDepthEstimation.from_pretrained(
                    self.DEPTH_MODEL_FALLBACK,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                self._depth_estimator.to(self.device)
                self._depth_estimator.eval()
                logger.info("Successfully loaded MiDaS fallback model")
            except Exception as fallback_e:
                logger.error(f"Fallback depth model also failed: {fallback_e}")
                raise RuntimeError("Unable to load any depth estimation model")

    def _apply_pipeline_optimizations(self) -> None:
        """Apply memory and performance optimizations to the pipeline."""
        if self._inpaint_pipeline is None:
            return
        # Try xformers first
        try:
            self._inpaint_pipeline.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xformers memory efficient attention")
        except Exception:
            try:
                self._inpaint_pipeline.enable_attention_slicing()
                logger.info("Enabled attention slicing")
            except Exception:
                logger.warning("No attention optimization available")
        # VAE optimizations
        if self.config.enable_vae_tiling:
            if hasattr(self._inpaint_pipeline, 'enable_vae_tiling'):
                self._inpaint_pipeline.enable_vae_tiling()
                logger.debug("Enabled VAE tiling")
            if hasattr(self._inpaint_pipeline, 'enable_vae_slicing'):
                self._inpaint_pipeline.enable_vae_slicing()
                logger.debug("Enabled VAE slicing")

    def _unload_pipeline(self) -> None:
        """Unload the inpainting pipeline and free memory."""
        logger.info("Unloading inpainting pipeline...")
        if self._inpaint_pipeline is not None:
            del self._inpaint_pipeline
            self._inpaint_pipeline = None
        if self._controlnet_canny is not None:
            del self._controlnet_canny
            self._controlnet_canny = None
        if self._controlnet_depth is not None:
            del self._controlnet_depth
            self._controlnet_depth = None
        if self._depth_estimator is not None:
            del self._depth_estimator
            self._depth_estimator = None
        if self._depth_processor is not None:
            del self._depth_processor
            self._depth_processor = None
        self.is_initialized = False
        self._current_conditioning_type = None
        self._cached_latents = None
        self._memory_cleanup(aggressive=True)
        logger.info("Inpainting pipeline unloaded")

    def prepare_control_image(
        self,
        image: Image.Image,
        mode: str = "canny",
        mask: Optional[Image.Image] = None,
        preserve_structure: bool = False
    ) -> Image.Image:
        """
        Generate ControlNet conditioning image.

        Parameters
        ----------
        image : PIL.Image
            Input image
        mode : str
            Conditioning mode: "canny" or "depth"
        mask : PIL.Image, optional
            If provided, can suppress edges in the masked region (when preserve_structure=False).
        preserve_structure : bool
            If True, keep edges in the masked region (for color change tasks).
            If False, suppress edges in the masked region (for replacement/removal tasks).

        Returns
        -------
        PIL.Image
            Generated control image (edges or depth map)
        """
        logger.info(f"Preparing control image with mode: {mode}, preserve_structure: {preserve_structure}")

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_array = np.array(image)

        if mode == "canny":
            canny_image = self._generate_canny_edges(img_array)
            # Mask-aware processing: suppress edges in the masked region ONLY if not preserving structure
            if mask is not None and not preserve_structure:
                canny_array = np.array(canny_image)
                mask_array = np.array(mask.convert('L'))
                # In the masked region, completely suppress Canny edges.
                # This allows complete replacement/removal of the object.
                mask_region = mask_array > 128  # White = masked area
                canny_array[mask_region] = 0
                canny_image = Image.fromarray(canny_array)
                logger.info("Suppressed edges in masked region for replacement/removal")
            elif preserve_structure:
                logger.info("Preserving edges in masked region for color change")
            return canny_image
        elif mode == "depth":
            return self._generate_depth_map(image)
        else:
            raise ValueError(f"Unknown control mode: {mode}")

    def _generate_canny_edges(self, img_array: np.ndarray) -> Image.Image:
        """
        Generate Canny edge detection image.

        Parameters
        ----------
        img_array : np.ndarray
            Input image as RGB numpy array

        Returns
        -------
        PIL.Image
            Edge detection result as a 3-channel image
        """
        # Convert to grayscale
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        # Apply Gaussian blur to reduce noise
        blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)
        # Canny edge detection
        edges = cv2.Canny(
            blurred,
            self.config.canny_low_threshold,
            self.config.canny_high_threshold
        )
        # Convert to 3-channel for ControlNet
        edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
        logger.debug(f"Generated Canny edges with thresholds "
                     f"{self.config.canny_low_threshold}/{self.config.canny_high_threshold}")
        return Image.fromarray(edges_3ch)

    def _generate_depth_map(self, image: Image.Image) -> Image.Image:
        """
        Generate depth map using the depth estimation model.

        Parameters
        ----------
        image : PIL.Image
            Input RGB image

        Returns
        -------
        PIL.Image
            Depth map as a 3-channel image
        """
        if self._depth_estimator is None or self._depth_processor is None:
            raise RuntimeError("Depth estimator not loaded")

        # Preprocess; cast pixel values to the estimator's dtype (fp16 on CUDA)
        # to avoid a dtype mismatch between inputs and model weights
        inputs = self._depth_processor(images=image, return_tensors="pt")
        model_dtype = next(self._depth_estimator.parameters()).dtype
        inputs = {k: v.to(self.device, dtype=model_dtype) for k, v in inputs.items()}

        # Inference
        with torch.no_grad():
            outputs = self._depth_estimator(**inputs)
            predicted_depth = outputs.predicted_depth

        # Interpolate to original size
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],  # (H, W)
            mode="bicubic",
            align_corners=False
        )

        # Normalize to 0-255
        depth_array = prediction.squeeze().cpu().numpy()
        depth_min = depth_array.min()
        depth_max = depth_array.max()
        if depth_max - depth_min > 0:
            depth_normalized = (depth_array - depth_min) / (depth_max - depth_min) * 255
        else:
            depth_normalized = np.zeros_like(depth_array)
        depth_normalized = depth_normalized.astype(np.uint8)

        # Convert to 3-channel for ControlNet
        depth_3ch = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
        logger.debug(f"Generated depth map, range: {depth_min:.2f} - {depth_max:.2f}")
        return Image.fromarray(depth_3ch)

    def prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        feather_radius: Optional[int] = None
    ) -> Tuple[Image.Image, Dict[str, Any]]:
        """
        Prepare and validate mask for inpainting.

        Parameters
        ----------
        mask : PIL.Image
            Input mask (white = inpaint area)
        target_size : tuple
            Target (width, height) to match input image
        feather_radius : int, optional
            Feathering radius in pixels. Uses config default if None.

        Returns
        -------
        tuple
            (processed_mask, validation_info)

        Notes
        -----
        Coverage outside the acceptable range is not raised as an error;
        it is reported via ``validation_info["valid"]`` and ``validation_info["warning"]``.
        """
        feather = feather_radius if feather_radius is not None else self.config.feather_radius

        # Convert to grayscale
        if mask.mode != 'L':
            mask = mask.convert('L')

        # Resize to match target
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        # Convert to array for processing
        mask_array = np.array(mask)

        # Calculate coverage
        total_pixels = mask_array.size
        white_pixels = np.count_nonzero(mask_array > 127)
        coverage = white_pixels / total_pixels

        validation_info = {
            "coverage": coverage,
            "white_pixels": white_pixels,
            "total_pixels": total_pixels,
            "feather_radius": feather,
            "valid": True,
            "warning": ""
        }

        # Validate coverage
        if coverage < self.config.min_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too low ({coverage:.1%}). "
                f"Please select a larger area to inpaint."
            )
            logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.config.min_mask_coverage:.1%}")
        elif coverage > self.config.max_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too high ({coverage:.1%}). "
                f"Consider using background generation instead."
            )
            logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.config.max_mask_coverage:.1%}")

        # Apply feathering
        if feather > 0:
            mask_array = cv2.GaussianBlur(
                mask_array,
                (feather * 2 + 1, feather * 2 + 1),
                feather / 2
            )
            logger.debug(f"Applied {feather}px feathering to mask")

        processed_mask = Image.fromarray(mask_array, mode='L')
        return processed_mask, validation_info

    def enhance_prompt_for_inpainting(
        self,
        prompt: str,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[str, str]:
        """
        Enhance prompt based on non-masked region analysis.

        Analyzes the surrounding context to generate appropriate
        lighting and color descriptors.

        Parameters
        ----------
        prompt : str
            User-provided prompt
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask

        Returns
        -------
        tuple
            (enhanced_prompt, negative_prompt)
        """
        logger.info("Enhancing prompt for inpainting context...")

        # Convert to arrays
        img_array = np.array(image.convert('RGB'))
        mask_array = np.array(mask.convert('L'))

        # Analyze non-masked regions
        non_masked = mask_array < 127
        if not np.any(non_masked):
            # No context available
            enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
            negative_prompt = self._get_inpainting_negative_prompt()
            return enhanced_prompt, negative_prompt

        # Extract context pixels
        context_pixels = img_array[non_masked]

        # Convert to Lab for analysis
        context_lab = cv2.cvtColor(
            context_pixels.reshape(-1, 1, 3),
            cv2.COLOR_RGB2LAB
        ).reshape(-1, 3)

        # Use robust statistics (median) to avoid outlier influence
        median_l = np.median(context_lab[:, 0])
        median_a = np.median(context_lab[:, 1])
        median_b = np.median(context_lab[:, 2])

        # Analyze lighting conditions
        lighting_descriptors = []
        if median_l > 170:
            lighting_descriptors.append("bright")
        elif median_l > 130:
            lighting_descriptors.append("well-lit")
        elif median_l > 80:
            lighting_descriptors.append("moderate lighting")
        else:
            lighting_descriptors.append("dim lighting")

        # Analyze color temperature (b channel: blue(-) to yellow(+))
        if median_b > 140:
            lighting_descriptors.append("warm golden tones")
        elif median_b > 120:
            lighting_descriptors.append("warm afternoon light")
        elif median_b < 110:
            lighting_descriptors.append("cool neutral tones")

        # Calculate saturation from context
        hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
        median_saturation = np.median(hsv[:, :, 1])
        if median_saturation > 150:
            lighting_descriptors.append("vibrant colors")
        elif median_saturation < 80:
            lighting_descriptors.append("subtle muted colors")

        # Build enhanced prompt
        lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
        quality_suffix = "high quality, detailed, photorealistic, seamless integration"
        if lighting_desc:
            enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
        else:
            enhanced_prompt = f"{prompt}, {quality_suffix}"

        negative_prompt = self._get_inpainting_negative_prompt()
        logger.info(f"Enhanced prompt with context: {lighting_desc}")
        return enhanced_prompt, negative_prompt

    def _get_inpainting_negative_prompt(self) -> str:
        """Get the standard negative prompt for inpainting."""
        return (
            "inconsistent lighting, wrong perspective, mismatched colors, "
            "visible seams, blending artifacts, color bleeding, "
            "blurry, low quality, distorted, deformed, "
            "harsh edges, unnatural transition"
        )

    def execute_inpainting(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        preview_only: bool = False,
        seed: Optional[int] = None,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> InpaintingResult:
        """
        Execute the inpainting operation.

        Implements two-stage generation: a fast preview followed by
        full quality generation if requested.

        Parameters
        ----------
        image : PIL.Image
            Original image to inpaint
        mask : PIL.Image
            Inpainting mask (white = area to regenerate)
        prompt : str
            Text description of desired content
        preview_only : bool
            If True, only generate the preview (faster)
        seed : int, optional
            Random seed for reproducibility
        progress_callback : callable, optional
            Progress update function(message, percentage)
        **kwargs
            Additional parameters:
            - controlnet_conditioning_scale: float
            - feather_radius: int
            - num_inference_steps: int
            - guidance_scale: float

        Returns
        -------
        InpaintingResult
            Result container with generated images and metadata
        """
        start_time = time.time()

        if not self.is_initialized:
            return InpaintingResult(
                success=False,
                error_message="Inpainting pipeline not initialized. Call load_inpainting_pipeline() first."
            )

        logger.info(f"Starting inpainting: prompt='{prompt[:50]}...', preview_only={preview_only}")
        try:
            # Update config with kwargs
            conditioning_scale = kwargs.get(
                'controlnet_conditioning_scale',
                self.config.controlnet_conditioning_scale
            )
            feather_radius = kwargs.get('feather_radius', self.config.feather_radius)
            strength = kwargs.get('strength', self.config.strength)
            preserve_structure = kwargs.get('preserve_structure_in_mask', False)

            if progress_callback:
                progress_callback("Preparing images...", 5)

            # Prepare image
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Ensure dimensions are a multiple of 8
            width, height = image.size
            new_width = (width // 8) * 8
            new_height = (height // 8) * 8
            if new_width != width or new_height != height:
                image = image.resize((new_width, new_height), Image.LANCZOS)

            # Check and potentially reduce resolution for memory
            max_res = self.config.max_resolution
            if max(new_width, new_height) > max_res:
                scale = max_res / max(new_width, new_height)
                new_width = int(new_width * scale) // 8 * 8
                new_height = int(new_height * scale) // 8 * 8
                image = image.resize((new_width, new_height), Image.LANCZOS)
                logger.info(f"Reduced resolution to {new_width}x{new_height} for memory")

            # Prepare mask
            if progress_callback:
                progress_callback("Processing mask...", 10)
            processed_mask, mask_info = self.prepare_mask(
                mask,
                (new_width, new_height),
                feather_radius
            )
            if not mask_info["valid"]:
                return InpaintingResult(
                    success=False,
                    error_message=mask_info["warning"]
                )

            # Generate control image
            if progress_callback:
                progress_callback("Generating control image...", 20)
            control_image = self.prepare_control_image(
                image,
                self._current_conditioning_type,
                mask=processed_mask,
                preserve_structure=preserve_structure  # True for color change, False for replacement/removal
            )

            # Conditional prompt enhancement based on the template:
            # check whether to enhance the prompt or use it directly.
            should_enhance = kwargs.get('enhance_prompt', False)  # Default: no enhancement
            if should_enhance:
                if progress_callback:
                    progress_callback("Enhancing prompt...", 25)
                enhanced_prompt, negative_prompt = self.enhance_prompt_for_inpainting(
                    prompt, image, processed_mask
                )
                logger.info("Prompt enhanced with surrounding-context analysis")
            else:
                # Use the prompt directly without enhancement
                enhanced_prompt = prompt
                negative_prompt = self._get_inpainting_negative_prompt()
                logger.info("Prompt enhancement disabled for this template")

            # Setup generator for reproducibility
            if seed is None:
                seed = int(time.time() * 1000) % (2**32)
            self._last_seed = seed
            generator = torch.Generator(device=self.device).manual_seed(seed)

            # Check if running on Hugging Face Spaces
            is_spaces = os.getenv('SPACE_ID') is not None

            # Stage 1: Preview generation.
            # On Spaces, skip the preview to save time (300s hard limit).
            preview_result = None
            if preview_only or not is_spaces:
                if progress_callback:
                    progress_callback("Generating preview...", 30)
                # Optimize preview steps for Hugging Face Spaces
                preview_steps = self.config.preview_steps
                if is_spaces:
                    # On Spaces, use minimal preview steps
                    preview_steps = min(preview_steps, 8)
                    logger.debug(f"Spaces environment - using {preview_steps} preview steps")
                preview_result = self._generate_inpaint(
                    image=image,
                    mask=processed_mask,
                    control_image=control_image,
                    prompt=enhanced_prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=preview_steps,
                    guidance_scale=self.config.preview_guidance_scale,
                    controlnet_conditioning_scale=conditioning_scale,
                    strength=strength,
                    generator=generator
                )
            else:
                logger.debug("Spaces environment - skipping preview to fit 300s limit")

            if preview_only:
                generation_time = time.time() - start_time
                return InpaintingResult(
                    success=True,
                    preview_image=preview_result,
                    control_image=control_image,
                    generation_time=generation_time,
                    metadata={
                        "seed": seed,
                        "prompt": enhanced_prompt,
                        "conditioning_type": self._current_conditioning_type,
                        "conditioning_scale": conditioning_scale,
                        "preview_only": True
                    }
                )

            # Stage 2: Full quality generation
            if progress_callback:
                progress_callback("Generating full quality...", 60)

            # Use the same seed for reproducibility
            generator = torch.Generator(device=self.device).manual_seed(seed)
            num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
            guidance = kwargs.get('guidance_scale', self.config.guidance_scale)

            # Optimize for Hugging Face Spaces ZeroGPU (stateless, 300s hard limit)
            if is_spaces:
                # ZeroGPU timing breakdown with model caching (actual measurements):
                # - Model loading from cache: ~60s (cached models, CPU to GPU transfer)
                # - Inference: ~28-29s/step (observed on shared H200)
                # - Blending & overhead: ~35s
                # - Platform limit: 300s hard limit (Pro tier)
                #
                # Strategy with the unified 10-step approach:
                # - Skip the preview completely (handled above)
                # - Use 10 steps to balance quality and speed
                # - Time budget: 60s (load) + 285s (10 steps) + 35s (blend) = 380s
                # - Note: this may still time out, but parameter optimization matters more than step count
                # - Quality comes from a correct conditioning_scale, not a high step count
                spaces_max_steps = 10  # 10 steps are sufficient with proper parameters
                if num_steps > spaces_max_steps:
                    num_steps = spaces_max_steps
                logger.debug(f"Spaces deployment: using {num_steps} steps (optimized for parameter quality)")

            full_result = self._generate_inpaint(
                image=image,
                mask=processed_mask,
                control_image=control_image,
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance,
                controlnet_conditioning_scale=conditioning_scale,
                strength=strength,
                generator=generator
            )

            if progress_callback:
                progress_callback("Blending result...", 90)

            # Blend result
            blended = self.blend_result(image, full_result, processed_mask)
            generation_time = time.time() - start_time

            if progress_callback:
                progress_callback("Complete!", 100)

            return InpaintingResult(
                success=True,
                result_image=full_result,
                preview_image=preview_result,
                control_image=control_image,
                blended_image=blended,
                generation_time=generation_time,
                metadata={
                    "seed": seed,
                    "prompt": enhanced_prompt,
                    "negative_prompt": negative_prompt,
                    "conditioning_type": self._current_conditioning_type,
                    "conditioning_scale": conditioning_scale,
                    "strength": strength,
                    "preserve_structure": preserve_structure,
                    "num_inference_steps": num_steps,
                    "guidance_scale": guidance,
                    "feather_radius": feather_radius,
                    "mask_coverage": mask_info["coverage"],
                    "preview_only": False
                }
            )

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory during inpainting")
            self._memory_cleanup(aggressive=True)
            return InpaintingResult(
                success=False,
                error_message="GPU memory exhausted. Try reducing image size or closing other applications."
            )
        except Exception as e:
            logger.error(f"Inpainting failed: {e}")
            logger.error(traceback.format_exc())
            return InpaintingResult(
                success=False,
                error_message=f"Inpainting failed: {str(e)}"
            )

    def _generate_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        control_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        guidance_scale: float,
        controlnet_conditioning_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """
        Internal method to run the inpainting pipeline.

        Supports both ControlNet and non-ControlNet pipelines.

        Parameters
        ----------
        image : PIL.Image
            Original image
        mask : PIL.Image
            Processed mask
        control_image : PIL.Image
            ControlNet conditioning image (ignored if ControlNet is not available)
        prompt : str
            Enhanced prompt
        negative_prompt : str
            Negative prompt
        num_inference_steps : int
            Number of denoising steps
        guidance_scale : float
            Classifier-free guidance scale
        controlnet_conditioning_scale : float
            ControlNet influence strength (ignored if ControlNet is not available)
        strength : float
            Inpainting strength (0.0-1.0). 1.0 = fully repaint the masked area.
        generator : torch.Generator
            Random generator for reproducibility

        Returns
        -------
        PIL.Image
            Generated image
        """
        with torch.inference_mode():
            if self._use_controlnet:
                # Full ControlNet inpainting pipeline
                result = self._inpaint_pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image,
                    mask_image=mask,
                    control_image=control_image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    controlnet_conditioning_scale=controlnet_conditioning_scale,
                    strength=strength,
                    generator=generator
                )
            else:
                # Fallback: standard SDXL inpainting without ControlNet
                result = self._inpaint_pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image,
                    mask_image=mask,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    strength=strength,
                    generator=generator
                )
        return result.images[0]

    def blend_result(
        self,
        original: Image.Image,
        generated: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Blend generated content with the original image.

        Uses linear color space blending for accurate results.

        Parameters
        ----------
        original : PIL.Image
            Original image
        generated : PIL.Image
            Generated inpainted image
        mask : PIL.Image
            Blending mask (white = use generated)

        Returns
        -------
        PIL.Image
            Blended result
        """
        logger.info("Blending inpainting result...")

        # Ensure same size
        if generated.size != original.size:
            generated = generated.resize(original.size, Image.LANCZOS)
        if mask.size != original.size:
            mask = mask.resize(original.size, Image.LANCZOS)

        # Convert to arrays
        orig_array = np.array(original.convert('RGB')).astype(np.float32)
        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0

        # sRGB to linear conversion
        def srgb_to_linear(img):
            img_norm = img / 255.0
            return np.where(
                img_norm <= 0.04045,
                img_norm / 12.92,
                np.power((img_norm + 0.055) / 1.055, 2.4)
            )

        def linear_to_srgb(img):
            img_clipped = np.clip(img, 0, 1)
            return np.where(
                img_clipped <= 0.0031308,
                12.92 * img_clipped,
                1.055 * np.power(img_clipped, 1 / 2.4) - 0.055
            )

        # Convert to linear space
        orig_linear = srgb_to_linear(orig_array)
        gen_linear = srgb_to_linear(gen_array)

        # Alpha blending in linear space
        alpha = mask_array[:, :, np.newaxis]
        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)

        # Convert back to sRGB
        result_srgb = linear_to_srgb(result_linear)
        result_array = (result_srgb * 255).astype(np.uint8)

        logger.debug("Blending completed in linear color space")
        return Image.fromarray(result_array)

    def execute_with_auto_optimization(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        quality_checker: Any,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> InpaintingResult:
        """
        Execute inpainting with automatic quality-based optimization.

        Retries with adjusted parameters if the quality score is below threshold.

        Parameters
        ----------
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask
        prompt : str
            Text prompt
        quality_checker : QualityChecker
            Quality assessment instance
        progress_callback : callable, optional
            Progress update function
        **kwargs
            Additional inpainting parameters

        Returns
        -------
        InpaintingResult
            Best result achieved (may include retry information)
        """
        if not self.config.enable_auto_optimization:
            return self.execute_inpainting(
                image, mask, prompt,
                progress_callback=progress_callback,
                **kwargs
            )

        best_result = None
        best_score = 0.0
        retry_count = 0
        prev_score = 0.0

        # Mutable parameters for optimization
        current_feather = kwargs.get('feather_radius', self.config.feather_radius)
        current_scale = kwargs.get(
            'controlnet_conditioning_scale',
            self.config.controlnet_conditioning_scale
        )
        current_guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
        current_prompt = prompt

        while retry_count <= self.config.max_optimization_retries:
            if progress_callback and retry_count > 0:
                progress_callback(f"Optimizing (attempt {retry_count + 1})...", 5)

            # Execute inpainting
            result = self.execute_inpainting(
                image, mask, current_prompt,
                preview_only=False,
                feather_radius=current_feather,
                controlnet_conditioning_scale=current_scale,
                guidance_scale=current_guidance,
                progress_callback=progress_callback if retry_count == 0 else None,
                **{k: v for k, v in kwargs.items()
                   if k not in ['feather_radius', 'controlnet_conditioning_scale',
                                'guidance_scale']}
            )
            if not result.success:
                return result

            # Evaluate quality
            if result.blended_image is not None:
                quality_results = quality_checker.run_all_checks(
                    foreground=image,
                    background=result.result_image,
                    mask=mask,
                    combined=result.blended_image
                )
                quality_score = quality_results.get("overall_score", 0)
            else:
                quality_score = 50.0  # Default if there is no blended image
                quality_results = {}  # Keep defined for the parameter-adjustment step below
            result.quality_score = quality_score
            result.quality_details = quality_results
            result.retries = retry_count
            logger.info(f"Quality score: {quality_score:.1f} (attempt {retry_count + 1})")

            # Track the best result
            if quality_score > best_score:
                best_score = quality_score
                best_result = result

            # Check if quality is acceptable
            if quality_score >= self.config.min_quality_score:
                logger.info(f"Quality threshold met: {quality_score:.1f}")
                return best_result

            # Check for minimal improvement (early termination)
            if retry_count > 0 and abs(quality_score - prev_score) < 5.0:
                logger.info("Minimal improvement, stopping optimization")
                return best_result
            prev_score = quality_score

            retry_count += 1
            if retry_count > self.config.max_optimization_retries:
                break

            # Adjust parameters based on quality issues
            checks = quality_results.get("checks", {})
            edge_score = checks.get("edge_continuity", {}).get("score", 100)
            harmony_score = checks.get("color_harmony", {}).get("score", 100)

            if edge_score < 60:
                # Edge issues: increase feathering, decrease control strength
                current_feather = min(20, current_feather + 3)
                current_scale = max(0.5, current_scale - 0.1)
                logger.debug(f"Adjusting for edges: feather={current_feather}, scale={current_scale}")
            if harmony_score < 60:
                # Color harmony issues: emphasize consistency in the prompt
                if "color consistent" not in current_prompt.lower():
                    current_prompt = f"{current_prompt}, color consistent with surroundings, matching lighting"
                current_guidance = min(12.0, current_guidance + 1.0)
                logger.debug(f"Adjusting for harmony: guidance={current_guidance}")
            if edge_score < 60 and harmony_score < 60:
                # Both issues: use stronger guidance
                current_guidance = min(12.0, current_guidance + 1.5)

        logger.info(f"Optimization complete. Best score: {best_score:.1f}")
        return best_result

    def get_status(self) -> Dict[str, Any]:
        """
        Get current module status.

        Returns
        -------
        dict
            Status information including initialization state and memory usage
        """
        status = {
            "initialized": self.is_initialized,
            "device": self.device,
            "conditioning_type": self._current_conditioning_type,
            "last_seed": self._last_seed,
            "config": {
                "controlnet_conditioning_scale": self.config.controlnet_conditioning_scale,
                "feather_radius": self.config.feather_radius,
                "num_inference_steps": self.config.num_inference_steps,
                "guidance_scale": self.config.guidance_scale
            }
        }
        status["memory"] = self._check_memory_status()
        return status
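

# A minimal usage sketch, not part of the module API: it assumes a GPU-backed
# environment and local files "photo.png" (RGB source) and "mask.png" (white =
# area to inpaint). The file names and prompt are illustrative placeholders only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    module = InpaintingModule(device="auto")
    ok, err = module.load_inpainting_pipeline(conditioning_type="canny")
    if not ok:
        raise SystemExit(f"Pipeline load failed: {err}")

    source = Image.open("photo.png").convert("RGB")  # hypothetical input path
    region = Image.open("mask.png").convert("L")     # hypothetical mask path

    outcome = module.execute_inpainting(
        image=source,
        mask=region,
        prompt="a wooden bench under a tree",        # illustrative prompt
        seed=42,
    )
    if outcome.success and outcome.blended_image is not None:
        outcome.blended_image.save("inpainted.png")
        print(f"Done in {outcome.generation_time:.1f}s, seed={outcome.metadata.get('seed')}")
    else:
        print(f"Inpainting failed: {outcome.error_message}")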