# SceneWeaver / inpainting_module.py
import gc
import logging
import os
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import torch
from PIL import Image, ImageFilter
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers import StableDiffusionXLControlNetInpaintPipeline
from diffusers import StableDiffusionXLInpaintPipeline
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from transformers import DPTImageProcessor, DPTForDepthEstimation
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@dataclass
class InpaintingConfig:
"""Configuration for inpainting operations."""
# ControlNet settings
controlnet_conditioning_scale: float = 0.7
conditioning_type: str = "canny" # "canny" or "depth"
# Canny edge detection parameters
canny_low_threshold: int = 100
canny_high_threshold: int = 200
# Mask settings
feather_radius: int = 8
min_mask_coverage: float = 0.01
max_mask_coverage: float = 0.95
# Generation settings
num_inference_steps: int = 25
guidance_scale: float = 7.5
strength: float = 1.0 # Inpainting strength (0.0-1.0), 1.0 = full repaint
preview_steps: int = 15
preview_guidance_scale: float = 8.0
# Quality settings
enable_auto_optimization: bool = True
max_optimization_retries: int = 3
min_quality_score: float = 70.0
# Memory settings
enable_vae_tiling: bool = True
enable_attention_slicing: bool = True
max_resolution: int = 1024
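# Illustrative sketch: overriding a few defaults for a softer mask edge and
# lighter ControlNet guidance (the values below are placeholders, not tuned
# recommendations):
#   config = InpaintingConfig(controlnet_conditioning_scale=0.5, feather_radius=12)
#   module = InpaintingModule(device="auto", config=config)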
@dataclass
class InpaintingResult:
"""Result container for inpainting operations."""
success: bool
result_image: Optional[Image.Image] = None
preview_image: Optional[Image.Image] = None
control_image: Optional[Image.Image] = None
blended_image: Optional[Image.Image] = None
quality_score: float = 0.0
quality_details: Dict[str, Any] = field(default_factory=dict)
generation_time: float = 0.0
retries: int = 0
error_message: str = ""
metadata: Dict[str, Any] = field(default_factory=dict)
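# Illustrative sketch of consuming a result (assumes `result` came from
# InpaintingModule.execute_inpainting; "output.png" is a placeholder path):
#   if result.success and result.blended_image is not None:
#       result.blended_image.save("output.png")
#   else:
#       logger.error(result.error_message)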
class InpaintingModule:
"""
ControlNet-based Inpainting Module for SceneWeaver.
Implements StableDiffusionXLControlNetInpaintPipeline with support for
Canny edge and depth map conditioning. Features two-stage generation
(preview + full quality) and automatic quality optimization.
Attributes:
device: Computation device (cuda/mps/cpu)
config: InpaintingConfig instance
is_initialized: Whether pipeline is loaded
Example:
>>> module = InpaintingModule(device="cuda")
>>> module.load_inpainting_pipeline(progress_callback=my_callback)
>>> result = module.execute_inpainting(
... image=my_image,
... mask=my_mask,
... prompt="a beautiful garden"
... )
"""
# Model identifiers
CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
def __init__(
self,
device: str = "auto",
config: Optional[InpaintingConfig] = None
):
"""
Initialize the InpaintingModule.
Parameters
----------
device : str, optional
Computation device. "auto" for automatic detection.
config : InpaintingConfig, optional
Configuration object. Uses defaults if not provided.
"""
self.device = self._setup_device(device)
self.config = config or InpaintingConfig()
# Pipeline instances (lazy loaded)
self._inpaint_pipeline = None
self._controlnet_canny = None
self._controlnet_depth = None
self._depth_estimator = None
self._depth_processor = None
# State tracking
self.is_initialized = False
self._current_conditioning_type = None
self._last_seed = None
self._cached_latents = None
self._use_controlnet = True # Track if ControlNet is available
# Reference to model manager (set by SceneWeaverCore)
self._model_manager = None
logger.info(f"InpaintingModule initialized on {self.device}")
def _setup_device(self, device: str) -> str:
"""
Setup computation device.
Parameters
----------
device : str
Device specification or "auto"
Returns
-------
str
Resolved device name
"""
if device == "auto":
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "mps"
return "cpu"
return device
def set_model_manager(self, manager: Any) -> None:
"""
Set reference to ModelManager for coordinated model lifecycle.
Parameters
----------
manager : ModelManager
The global model manager instance
"""
self._model_manager = manager
logger.info("ModelManager reference set for InpaintingModule")
def _memory_cleanup(self, aggressive: bool = False) -> None:
"""
Perform memory cleanup.
Parameters
----------
aggressive : bool
If True, perform multiple GC rounds and sync CUDA
"""
rounds = 5 if aggressive else 2
for _ in range(rounds):
gc.collect()
# On Hugging Face Spaces, avoid CUDA operations in main process
# CUDA operations must only happen within @spaces.GPU decorated functions
is_spaces = os.getenv('SPACE_ID') is not None
if not is_spaces and torch.cuda.is_available():
torch.cuda.empty_cache()
if aggressive:
torch.cuda.ipc_collect()
torch.cuda.synchronize()
logger.debug(f"Memory cleanup completed (aggressive={aggressive}, spaces={is_spaces})")
def _check_memory_status(self) -> Dict[str, float]:
"""
Check current GPU memory status.
Returns
-------
dict
Memory statistics including allocated, total, and usage ratio
"""
# On Spaces, skip CUDA checks in main process
is_spaces = os.getenv('SPACE_ID') is not None
if is_spaces or not torch.cuda.is_available():
return {"available": True, "usage_ratio": 0.0}
allocated = torch.cuda.memory_allocated() / 1024**3
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
usage_ratio = allocated / total
return {
"allocated_gb": round(allocated, 2),
"total_gb": round(total, 2),
"free_gb": round(total - allocated, 2),
"usage_ratio": round(usage_ratio, 3),
"available": usage_ratio < 0.9
}
def load_inpainting_pipeline(
self,
conditioning_type: str = "canny",
progress_callback: Optional[Callable[[str, int], None]] = None
) -> Tuple[bool, str]:
"""
Load the ControlNet inpainting pipeline.
Implements mutual exclusion with background generation pipeline.
Only one pipeline can be loaded at a time.
Parameters
----------
conditioning_type : str
Type of ControlNet conditioning: "canny" or "depth"
progress_callback : callable, optional
Function(message, percentage) for progress updates
Returns
-------
tuple
(success: bool, error_message: str)
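Examples
--------
Illustrative sketch; `module` and the callback are placeholders, and load
time depends on whether the models are already cached:
>>> ok, err = module.load_inpainting_pipeline(
...     conditioning_type="depth",
...     progress_callback=lambda msg, pct: print(f"{pct}% {msg}")
... )
>>> if not ok:
...     print(err)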
"""
if self.is_initialized and self._current_conditioning_type == conditioning_type:
logger.info(f"Inpainting pipeline already loaded with {conditioning_type}")
return True, ""
logger.info(f"Loading inpainting pipeline with {conditioning_type} conditioning...")
try:
self._memory_cleanup(aggressive=True)
if progress_callback:
progress_callback("Preparing to load inpainting models...", 5)
# Unload existing pipeline if different conditioning type
if self._inpaint_pipeline is not None:
self._unload_pipeline()
# Use ControlNet inpainting by default
use_controlnet_inpaint = True
logger.info("Using StableDiffusionXLControlNetInpaintPipeline")
if progress_callback:
progress_callback("Loading ControlNet model...", 20)
# Load appropriate ControlNet
dtype = torch.float16 if self.device == "cuda" else torch.float32
controlnet = None
if use_controlnet_inpaint:
if conditioning_type == "canny":
controlnet = ControlNetModel.from_pretrained(
self.CONTROLNET_CANNY_MODEL,
torch_dtype=dtype,
use_safetensors=True
)
self._controlnet_canny = controlnet
logger.info("Loaded ControlNet Canny model")
elif conditioning_type == "depth":
controlnet = ControlNetModel.from_pretrained(
self.CONTROLNET_DEPTH_MODEL,
torch_dtype=dtype,
use_safetensors=True
)
self._controlnet_depth = controlnet
# Load depth estimator
if progress_callback:
progress_callback("Loading depth estimation model...", 35)
self._load_depth_estimator()
logger.info("Loaded ControlNet Depth model")
else:
raise ValueError(f"Unknown conditioning type: {conditioning_type}")
else:
# Skip ControlNet loading for fallback mode
logger.info(f"Skipping ControlNet loading (fallback mode)")
if progress_callback:
progress_callback("Loading SDXL Inpainting pipeline...", 50)
# Load the inpainting pipeline
if use_controlnet_inpaint and controlnet is not None:
self._inpaint_pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
self.BASE_MODEL,
controlnet=controlnet,
torch_dtype=dtype,
use_safetensors=True,
variant="fp16" if dtype == torch.float16 else None
)
else:
# Fallback: Use dedicated inpainting model without ControlNet
self._inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
torch_dtype=dtype,
use_safetensors=True,
variant="fp16" if dtype == torch.float16 else None
)
self._use_controlnet = False
# Track ControlNet usage
self._use_controlnet = use_controlnet_inpaint and controlnet is not None
if progress_callback:
progress_callback("Configuring scheduler...", 70)
# Configure scheduler for faster generation
self._inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
self._inpaint_pipeline.scheduler.config
)
# Move to device
self._inpaint_pipeline = self._inpaint_pipeline.to(self.device)
if progress_callback:
progress_callback("Applying optimizations...", 85)
# Apply memory optimizations
self._apply_pipeline_optimizations()
# Set eval mode
self._inpaint_pipeline.unet.eval()
if hasattr(self._inpaint_pipeline, 'vae'):
self._inpaint_pipeline.vae.eval()
self.is_initialized = True
self._current_conditioning_type = conditioning_type if self._use_controlnet else "none"
if progress_callback:
progress_callback("Inpainting pipeline ready!", 100)
# Log memory status
mem_status = self._check_memory_status()
logger.info(f"Pipeline loaded. GPU memory: {mem_status.get('allocated_gb', 0):.1f}GB used")
return True, ""
except Exception as e:
error_msg = str(e)
logger.error(f"Failed to load inpainting pipeline: {error_msg}")
logger.error(traceback.format_exc())
self._unload_pipeline()
return False, error_msg
def _load_depth_estimator(self) -> None:
"""
Load depth estimation model with fallback strategy.
Tries Depth-Anything first, falls back to MiDaS if unavailable.
"""
try:
logger.info(f"Attempting to load depth model: {self.DEPTH_MODEL_PRIMARY}")
self._depth_processor = AutoImageProcessor.from_pretrained(
self.DEPTH_MODEL_PRIMARY
)
self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
self.DEPTH_MODEL_PRIMARY,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
)
self._depth_estimator.to(self.device)
self._depth_estimator.eval()
logger.info("Successfully loaded Depth-Anything model")
except Exception as e:
logger.warning(f"Primary depth model failed: {e}, trying fallback...")
try:
self._depth_processor = DPTImageProcessor.from_pretrained(
self.DEPTH_MODEL_FALLBACK
)
self._depth_estimator = DPTForDepthEstimation.from_pretrained(
self.DEPTH_MODEL_FALLBACK,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
)
self._depth_estimator.to(self.device)
self._depth_estimator.eval()
logger.info("Successfully loaded MiDaS fallback model")
except Exception as fallback_e:
logger.error(f"Fallback depth model also failed: {fallback_e}")
raise RuntimeError("Unable to load any depth estimation model")
def _apply_pipeline_optimizations(self) -> None:
"""Apply memory and performance optimizations to the pipeline."""
if self._inpaint_pipeline is None:
return
# Try xformers first
try:
self._inpaint_pipeline.enable_xformers_memory_efficient_attention()
logger.info("Enabled xformers memory efficient attention")
except Exception:
try:
self._inpaint_pipeline.enable_attention_slicing()
logger.info("Enabled attention slicing")
except Exception:
logger.warning("No attention optimization available")
# VAE optimizations
if self.config.enable_vae_tiling:
if hasattr(self._inpaint_pipeline, 'enable_vae_tiling'):
self._inpaint_pipeline.enable_vae_tiling()
logger.debug("Enabled VAE tiling")
if hasattr(self._inpaint_pipeline, 'enable_vae_slicing'):
self._inpaint_pipeline.enable_vae_slicing()
logger.debug("Enabled VAE slicing")
def _unload_pipeline(self) -> None:
"""Unload the inpainting pipeline and free memory."""
logger.info("Unloading inpainting pipeline...")
if self._inpaint_pipeline is not None:
del self._inpaint_pipeline
self._inpaint_pipeline = None
if self._controlnet_canny is not None:
del self._controlnet_canny
self._controlnet_canny = None
if self._controlnet_depth is not None:
del self._controlnet_depth
self._controlnet_depth = None
if self._depth_estimator is not None:
del self._depth_estimator
self._depth_estimator = None
if self._depth_processor is not None:
del self._depth_processor
self._depth_processor = None
self.is_initialized = False
self._current_conditioning_type = None
self._cached_latents = None
self._memory_cleanup(aggressive=True)
logger.info("Inpainting pipeline unloaded")
def prepare_control_image(
self,
image: Image.Image,
mode: str = "canny",
mask: Optional[Image.Image] = None,
preserve_structure: bool = False
) -> Image.Image:
"""
Generate ControlNet conditioning image.
Parameters
----------
image : PIL.Image
Input image
mode : str
Conditioning mode: "canny" or "depth"
mask : PIL.Image, optional
If provided, can suppress edges in masked region (when preserve_structure=False).
preserve_structure : bool
If True, keep edges in masked region (for color change tasks).
If False, suppress edges in masked region (for replacement/removal tasks).
Returns
-------
PIL.Image
Generated control image (edges or depth map)
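Examples
--------
Illustrative sketch (assumes the pipeline is loaded and `img` / `msk` are
placeholder PIL images):
>>> edges = module.prepare_control_image(img, mode="canny", mask=msk,
...                                      preserve_structure=False)
>>> edges.size == img.size
True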
"""
logger.info(f"Preparing control image with mode: {mode}, preserve_structure: {preserve_structure}")
# Convert to RGB if needed
if image.mode != 'RGB':
image = image.convert('RGB')
img_array = np.array(image)
if mode == "canny":
canny_image = self._generate_canny_edges(img_array)
# Mask-aware processing: suppress edges in masked region ONLY if not preserving structure
if mask is not None and not preserve_structure:
canny_array = np.array(canny_image)
mask_array = np.array(mask.convert('L'))
# In masked region, completely suppress Canny edges
# This allows complete replacement/removal of the object
mask_region = mask_array > 128 # White = masked area
canny_array[mask_region] = 0
canny_image = Image.fromarray(canny_array)
logger.info("Suppressed edges in masked region for replacement/removal")
elif preserve_structure:
logger.info("Preserving edges in masked region for color change")
return canny_image
elif mode == "depth":
return self._generate_depth_map(image)
else:
raise ValueError(f"Unknown control mode: {mode}")
def _generate_canny_edges(self, img_array: np.ndarray) -> Image.Image:
"""
Generate Canny edge detection image.
Parameters
----------
img_array : np.ndarray
Input image as RGB numpy array
Returns
-------
PIL.Image
Edge detection result as grayscale image
"""
# Convert to grayscale
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
# Apply Gaussian blur to reduce noise
blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)
# Canny edge detection
edges = cv2.Canny(
blurred,
self.config.canny_low_threshold,
self.config.canny_high_threshold
)
# Convert to 3-channel for ControlNet
edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
logger.debug(f"Generated Canny edges with thresholds "
f"{self.config.canny_low_threshold}/{self.config.canny_high_threshold}")
return Image.fromarray(edges_3ch)
def _generate_depth_map(self, image: Image.Image) -> Image.Image:
"""
Generate depth map using depth estimation model.
Parameters
----------
image : PIL.Image
Input RGB image
Returns
-------
PIL.Image
Depth map as grayscale image
"""
if self._depth_estimator is None or self._depth_processor is None:
raise RuntimeError("Depth estimator not loaded")
# Preprocess (cast floating inputs to the estimator's dtype so a float16
# model does not receive float32 pixel values)
inputs = self._depth_processor(images=image, return_tensors="pt")
model_dtype = next(self._depth_estimator.parameters()).dtype
inputs = {k: v.to(self.device, dtype=model_dtype) if torch.is_floating_point(v) else v.to(self.device)
for k, v in inputs.items()}
# Inference
with torch.no_grad():
outputs = self._depth_estimator(**inputs)
predicted_depth = outputs.predicted_depth
# Interpolate to original size
prediction = torch.nn.functional.interpolate(
predicted_depth.unsqueeze(1),
size=image.size[::-1], # (H, W)
mode="bicubic",
align_corners=False
)
# Normalize to 0-255
depth_array = prediction.squeeze().cpu().numpy()
depth_min = depth_array.min()
depth_max = depth_array.max()
if depth_max - depth_min > 0:
depth_normalized = ((depth_array - depth_min) / (depth_max - depth_min) * 255)
else:
depth_normalized = np.zeros_like(depth_array)
depth_normalized = depth_normalized.astype(np.uint8)
# Convert to 3-channel for ControlNet
depth_3ch = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
logger.debug(f"Generated depth map, range: {depth_min:.2f} - {depth_max:.2f}")
return Image.fromarray(depth_3ch)
def prepare_mask(
self,
mask: Image.Image,
target_size: Tuple[int, int],
feather_radius: Optional[int] = None
) -> Tuple[Image.Image, Dict[str, Any]]:
"""
Prepare and validate mask for inpainting.
Parameters
----------
mask : PIL.Image
Input mask (white = inpaint area)
target_size : tuple
Target (width, height) to match input image
feather_radius : int, optional
Feathering radius in pixels. Uses config default if None.
Returns
-------
tuple
(processed_mask, validation_info)
Notes
-----
Out-of-range mask coverage does not raise; it is reported via
``validation_info["valid"]`` and ``validation_info["warning"]``.
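Examples
--------
Illustrative sketch (assumes `module` exists and `msk` is a rough user mask;
names are placeholders):
>>> soft_mask, info = module.prepare_mask(msk, target_size=(1024, 1024))
>>> info["valid"], round(info["coverage"], 2)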
"""
feather = feather_radius if feather_radius is not None else self.config.feather_radius
# Convert to grayscale
if mask.mode != 'L':
mask = mask.convert('L')
# Resize to match target
if mask.size != target_size:
mask = mask.resize(target_size, Image.LANCZOS)
# Convert to array for processing
mask_array = np.array(mask)
# Calculate coverage
total_pixels = mask_array.size
white_pixels = np.count_nonzero(mask_array > 127)
coverage = white_pixels / total_pixels
validation_info = {
"coverage": coverage,
"white_pixels": white_pixels,
"total_pixels": total_pixels,
"feather_radius": feather,
"valid": True,
"warning": ""
}
# Validate coverage
if coverage < self.config.min_mask_coverage:
validation_info["valid"] = False
validation_info["warning"] = (
f"Mask coverage too low ({coverage:.1%}). "
f"Please select a larger area to inpaint."
)
logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.config.min_mask_coverage:.1%}")
elif coverage > self.config.max_mask_coverage:
validation_info["valid"] = False
validation_info["warning"] = (
f"Mask coverage too high ({coverage:.1%}). "
f"Consider using background generation instead."
)
logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.config.max_mask_coverage:.1%}")
# Apply feathering
if feather > 0:
mask_array = cv2.GaussianBlur(
mask_array,
(feather * 2 + 1, feather * 2 + 1),
feather / 2
)
logger.debug(f"Applied {feather}px feathering to mask")
processed_mask = Image.fromarray(mask_array, mode='L')
return processed_mask, validation_info
def enhance_prompt_for_inpainting(
self,
prompt: str,
image: Image.Image,
mask: Image.Image
) -> Tuple[str, str]:
"""
Enhance prompt based on non-masked region analysis.
Analyzes the surrounding context to generate appropriate
lighting and color descriptors.
Parameters
----------
prompt : str
User-provided prompt
image : PIL.Image
Original image
mask : PIL.Image
Inpainting mask
Returns
-------
tuple
(enhanced_prompt, negative_prompt)
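Examples
--------
Illustrative sketch (`module`, `img`, and `msk` are placeholders for a loaded
module, the source image, and its mask):
>>> pos, neg = module.enhance_prompt_for_inpainting("a wooden bench", img, msk)
>>> "photorealistic" in pos
True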
"""
logger.info("Enhancing prompt for inpainting context...")
# Convert to arrays
img_array = np.array(image.convert('RGB'))
mask_array = np.array(mask.convert('L'))
# Analyze non-masked regions
non_masked = mask_array < 127
if not np.any(non_masked):
# No context available
enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
negative_prompt = self._get_inpainting_negative_prompt()
return enhanced_prompt, negative_prompt
# Extract context pixels
context_pixels = img_array[non_masked]
# Convert to Lab for analysis
context_lab = cv2.cvtColor(
context_pixels.reshape(-1, 1, 3),
cv2.COLOR_RGB2LAB
).reshape(-1, 3)
# Use robust statistics (median) to avoid outlier influence
median_l = np.median(context_lab[:, 0])
median_a = np.median(context_lab[:, 1])
median_b = np.median(context_lab[:, 2])
# Analyze lighting conditions
lighting_descriptors = []
if median_l > 170:
lighting_descriptors.append("bright")
elif median_l > 130:
lighting_descriptors.append("well-lit")
elif median_l > 80:
lighting_descriptors.append("moderate lighting")
else:
lighting_descriptors.append("dim lighting")
# Analyze color temperature (b channel: blue(-) to yellow(+))
if median_b > 140:
lighting_descriptors.append("warm golden tones")
elif median_b > 120:
lighting_descriptors.append("warm afternoon light")
elif median_b < 110:
lighting_descriptors.append("cool neutral tones")
# Calculate saturation from context
hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
median_saturation = np.median(hsv[:, :, 1])
if median_saturation > 150:
lighting_descriptors.append("vibrant colors")
elif median_saturation < 80:
lighting_descriptors.append("subtle muted colors")
# Build enhanced prompt
lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
quality_suffix = "high quality, detailed, photorealistic, seamless integration"
if lighting_desc:
enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
else:
enhanced_prompt = f"{prompt}, {quality_suffix}"
negative_prompt = self._get_inpainting_negative_prompt()
logger.info(f"Enhanced prompt with context: {lighting_desc}")
return enhanced_prompt, negative_prompt
def _get_inpainting_negative_prompt(self) -> str:
"""Get standard negative prompt for inpainting."""
return (
"inconsistent lighting, wrong perspective, mismatched colors, "
"visible seams, blending artifacts, color bleeding, "
"blurry, low quality, distorted, deformed, "
"harsh edges, unnatural transition"
)
def execute_inpainting(
self,
image: Image.Image,
mask: Image.Image,
prompt: str,
preview_only: bool = False,
seed: Optional[int] = None,
progress_callback: Optional[Callable[[str, int], None]] = None,
**kwargs
) -> InpaintingResult:
"""
Execute the inpainting operation.
Implements two-stage generation: fast preview followed by
full quality generation if requested.
Parameters
----------
image : PIL.Image
Original image to inpaint
mask : PIL.Image
Inpainting mask (white = area to regenerate)
prompt : str
Text description of desired content
preview_only : bool
If True, only generate preview (faster)
seed : int, optional
Random seed for reproducibility
progress_callback : callable, optional
Progress update function(message, percentage)
**kwargs
Additional parameters:
- controlnet_conditioning_scale: float
- feather_radius: int
- num_inference_steps: int
- guidance_scale: float
Returns
-------
InpaintingResult
Result container with generated images and metadata
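Examples
--------
Illustrative sketch (assumes the pipeline is loaded; `img` and `msk` are
placeholder PIL images):
>>> result = module.execute_inpainting(
...     image=img,
...     mask=msk,
...     prompt="a stone fountain",
...     seed=1234,
...     num_inference_steps=20
... )
>>> if result.success:
...     final = result.blended_image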
"""
start_time = time.time()
if not self.is_initialized:
return InpaintingResult(
success=False,
error_message="Inpainting pipeline not initialized. Call load_inpainting_pipeline() first."
)
logger.info(f"Starting inpainting: prompt='{prompt[:50]}...', preview_only={preview_only}")
try:
# Update config with kwargs
conditioning_scale = kwargs.get(
'controlnet_conditioning_scale',
self.config.controlnet_conditioning_scale
)
feather_radius = kwargs.get('feather_radius', self.config.feather_radius)
strength = kwargs.get('strength', self.config.strength)
preserve_structure = kwargs.get('preserve_structure_in_mask', False)
if progress_callback:
progress_callback("Preparing images...", 5)
# Prepare image
if image.mode != 'RGB':
image = image.convert('RGB')
# Ensure dimensions are multiple of 8
width, height = image.size
new_width = (width // 8) * 8
new_height = (height // 8) * 8
if new_width != width or new_height != height:
image = image.resize((new_width, new_height), Image.LANCZOS)
# Check and potentially reduce resolution for memory
max_res = self.config.max_resolution
if max(new_width, new_height) > max_res:
scale = max_res / max(new_width, new_height)
new_width = int(new_width * scale) // 8 * 8
new_height = int(new_height * scale) // 8 * 8
image = image.resize((new_width, new_height), Image.LANCZOS)
logger.info(f"Reduced resolution to {new_width}x{new_height} for memory")
# Prepare mask
if progress_callback:
progress_callback("Processing mask...", 10)
processed_mask, mask_info = self.prepare_mask(
mask,
(new_width, new_height),
feather_radius
)
if not mask_info["valid"]:
return InpaintingResult(
success=False,
error_message=mask_info["warning"]
)
# Generate control image
if progress_callback:
progress_callback("Generating control image...", 20)
control_image = self.prepare_control_image(
image,
self._current_conditioning_type,
mask=processed_mask,
preserve_structure=preserve_structure # True for color change, False for replacement/removal
)
# Conditional prompt enhancement based on template
# Check if we should enhance the prompt or use it directly
should_enhance = kwargs.get('enhance_prompt', False) # Default: no enhancement
if should_enhance:
if progress_callback:
progress_callback("Enhancing prompt...", 25)
enhanced_prompt, negative_prompt = self.enhance_prompt_for_inpainting(
prompt, image, processed_mask
)
logger.info(f"Prompt enhanced with OpenCLIP context")
else:
# Use prompt directly without enhancement
enhanced_prompt = prompt
negative_prompt = self._get_inpainting_negative_prompt()
logger.info("Prompt enhancement disabled for this template")
# Setup generator for reproducibility
if seed is None:
seed = int(time.time() * 1000) % (2**32)
self._last_seed = seed
generator = torch.Generator(device=self.device).manual_seed(seed)
# Check if running on Hugging Face Spaces
is_spaces = os.getenv('SPACE_ID') is not None
# Stage 1: Preview generation
# On Spaces, skip preview to save time (300s hard limit)
preview_result = None
if preview_only or not is_spaces:
if progress_callback:
progress_callback("Generating preview...", 30)
# Optimize preview steps for Hugging Face Spaces
preview_steps = self.config.preview_steps
if is_spaces:
# On Spaces, use minimal preview steps
preview_steps = min(preview_steps, 8)
logger.debug(f"Spaces environment - using {preview_steps} preview steps")
preview_result = self._generate_inpaint(
image=image,
mask=processed_mask,
control_image=control_image,
prompt=enhanced_prompt,
negative_prompt=negative_prompt,
num_inference_steps=preview_steps,
guidance_scale=self.config.preview_guidance_scale,
controlnet_conditioning_scale=conditioning_scale,
strength=strength,
generator=generator
)
else:
logger.debug("Spaces environment - skipping preview to fit 300s limit")
if preview_only:
generation_time = time.time() - start_time
return InpaintingResult(
success=True,
preview_image=preview_result,
control_image=control_image,
generation_time=generation_time,
metadata={
"seed": seed,
"prompt": enhanced_prompt,
"conditioning_type": self._current_conditioning_type,
"conditioning_scale": conditioning_scale,
"preview_only": True
}
)
# Stage 2: Full quality generation
if progress_callback:
progress_callback("Generating full quality...", 60)
# Use same seed for reproducibility
generator = torch.Generator(device=self.device).manual_seed(seed)
num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
# Optimize for Hugging Face Spaces ZeroGPU (stateless, 300s hard limit)
if is_spaces:
# ZeroGPU timing breakdown with model caching (actual measurements):
# - Model loading from cache: ~60s (cached models, CPU to GPU transfer)
# - Inference: ~28-29s/step (observed on shared H200)
# - Blending & overhead: ~35s
# - Platform limit: 300s hard limit (Pro tier)
#
# Strategy with unified 10-step approach:
# - Skip preview completely (done above)
# - Use 10 steps for balance of quality and speed
# - Time budget: 60s (load) + 285s (10 steps) + 35s (blend) = 380s
# - Note: Still may timeout, but parameter optimization is more important than step count
# - Quality comes from correct conditioning_scale, not high step count
spaces_max_steps = 10 # Optimized: 10 steps sufficient with proper parameters
if num_steps > spaces_max_steps:
num_steps = spaces_max_steps
logger.debug(f"Spaces deployment: using {num_steps} steps (optimized for parameter quality)")
full_result = self._generate_inpaint(
image=image,
mask=processed_mask,
control_image=control_image,
prompt=enhanced_prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_steps,
guidance_scale=guidance,
controlnet_conditioning_scale=conditioning_scale,
strength=strength,
generator=generator
)
if progress_callback:
progress_callback("Blending result...", 90)
# Blend result
blended = self.blend_result(image, full_result, processed_mask)
generation_time = time.time() - start_time
if progress_callback:
progress_callback("Complete!", 100)
return InpaintingResult(
success=True,
result_image=full_result,
preview_image=preview_result,
control_image=control_image,
blended_image=blended,
generation_time=generation_time,
metadata={
"seed": seed,
"prompt": enhanced_prompt,
"negative_prompt": negative_prompt,
"conditioning_type": self._current_conditioning_type,
"conditioning_scale": conditioning_scale,
"strength": strength,
"preserve_structure": preserve_structure,
"num_inference_steps": num_steps,
"guidance_scale": guidance,
"feather_radius": feather_radius,
"mask_coverage": mask_info["coverage"],
"preview_only": False
}
)
except torch.cuda.OutOfMemoryError:
logger.error("CUDA out of memory during inpainting")
self._memory_cleanup(aggressive=True)
return InpaintingResult(
success=False,
error_message="GPU memory exhausted. Try reducing image size or closing other applications."
)
except Exception as e:
logger.error(f"Inpainting failed: {e}")
logger.error(traceback.format_exc())
return InpaintingResult(
success=False,
error_message=f"Inpainting failed: {str(e)}"
)
def _generate_inpaint(
self,
image: Image.Image,
mask: Image.Image,
control_image: Image.Image,
prompt: str,
negative_prompt: str,
num_inference_steps: int,
guidance_scale: float,
controlnet_conditioning_scale: float,
strength: float,
generator: torch.Generator
) -> Image.Image:
"""
Internal method to run the inpainting pipeline.
Supports both ControlNet and non-ControlNet pipelines.
Parameters
----------
image : PIL.Image
Original image
mask : PIL.Image
Processed mask
control_image : PIL.Image
ControlNet conditioning image (ignored if ControlNet not available)
prompt : str
Enhanced prompt
negative_prompt : str
Negative prompt
num_inference_steps : int
Number of denoising steps
guidance_scale : float
Classifier-free guidance scale
controlnet_conditioning_scale : float
ControlNet influence strength (ignored if ControlNet not available)
strength : float
Inpainting strength (0.0-1.0). 1.0 = fully repaint masked area.
generator : torch.Generator
Random generator for reproducibility
Returns
-------
PIL.Image
Generated image
"""
with torch.inference_mode():
if self._use_controlnet:
# Full ControlNet inpainting pipeline
result = self._inpaint_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
mask_image=mask,
control_image=control_image,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
strength=strength,
generator=generator
)
else:
# Fallback: Standard SDXL inpainting without ControlNet
result = self._inpaint_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
mask_image=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
strength=strength,
generator=generator
)
return result.images[0]
def blend_result(
self,
original: Image.Image,
generated: Image.Image,
mask: Image.Image
) -> Image.Image:
"""
Blend generated content with original image.
Uses linear color space blending for accurate results.
Parameters
----------
original : PIL.Image
Original image
generated : PIL.Image
Generated inpainted image
mask : PIL.Image
Blending mask (white = use generated)
Returns
-------
PIL.Image
Blended result
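Notes
-----
Blending is a straight alpha composite performed in linear light, assuming
the sRGB transfer function: each channel is linearised with
``c_lin = ((c + 0.055) / 1.055) ** 2.4`` (or ``c / 12.92`` for small values),
mixed as ``out = gen * alpha + orig * (1 - alpha)``, then re-encoded to sRGB.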
"""
logger.info("Blending inpainting result...")
# Ensure same size
if generated.size != original.size:
generated = generated.resize(original.size, Image.LANCZOS)
if mask.size != original.size:
mask = mask.resize(original.size, Image.LANCZOS)
# Convert to arrays
orig_array = np.array(original.convert('RGB')).astype(np.float32)
gen_array = np.array(generated.convert('RGB')).astype(np.float32)
mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
# sRGB to linear conversion
def srgb_to_linear(img):
img_norm = img / 255.0
return np.where(
img_norm <= 0.04045,
img_norm / 12.92,
np.power((img_norm + 0.055) / 1.055, 2.4)
)
def linear_to_srgb(img):
img_clipped = np.clip(img, 0, 1)
return np.where(
img_clipped <= 0.0031308,
12.92 * img_clipped,
1.055 * np.power(img_clipped, 1/2.4) - 0.055
)
# Convert to linear space
orig_linear = srgb_to_linear(orig_array)
gen_linear = srgb_to_linear(gen_array)
# Alpha blending in linear space
alpha = mask_array[:, :, np.newaxis]
result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
# Convert back to sRGB
result_srgb = linear_to_srgb(result_linear)
result_array = (result_srgb * 255).astype(np.uint8)
logger.debug("Blending completed in linear color space")
return Image.fromarray(result_array)
def execute_with_auto_optimization(
self,
image: Image.Image,
mask: Image.Image,
prompt: str,
quality_checker: Any,
progress_callback: Optional[Callable[[str, int], None]] = None,
**kwargs
) -> InpaintingResult:
"""
Execute inpainting with automatic quality-based optimization.
Retries with adjusted parameters if quality score is below threshold.
Parameters
----------
image : PIL.Image
Original image
mask : PIL.Image
Inpainting mask
prompt : str
Text prompt
quality_checker : QualityChecker
Quality assessment instance
progress_callback : callable, optional
Progress update function
**kwargs
Additional inpainting parameters
Returns
-------
InpaintingResult
Best result achieved (may include retry information)
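Examples
--------
Illustrative sketch; `checker` is a placeholder for any object exposing
``run_all_checks(...)`` that returns an ``overall_score``:
>>> result = module.execute_with_auto_optimization(
...     img, msk, "a tiled patio", quality_checker=checker
... )
>>> result.quality_score, result.retries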
"""
if not self.config.enable_auto_optimization:
return self.execute_inpainting(
image, mask, prompt,
progress_callback=progress_callback,
**kwargs
)
best_result = None
best_score = 0.0
retry_count = 0
prev_score = 0.0
# Mutable parameters for optimization
current_feather = kwargs.get('feather_radius', self.config.feather_radius)
current_scale = kwargs.get(
'controlnet_conditioning_scale',
self.config.controlnet_conditioning_scale
)
current_guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
current_prompt = prompt
while retry_count <= self.config.max_optimization_retries:
if progress_callback and retry_count > 0:
progress_callback(f"Optimizing (attempt {retry_count + 1})...", 5)
# Execute inpainting
result = self.execute_inpainting(
image, mask, current_prompt,
preview_only=False,
feather_radius=current_feather,
controlnet_conditioning_scale=current_scale,
guidance_scale=current_guidance,
progress_callback=progress_callback if retry_count == 0 else None,
**{k: v for k, v in kwargs.items()
if k not in ['feather_radius', 'controlnet_conditioning_scale',
'guidance_scale']}
)
if not result.success:
return result
# Evaluate quality
if result.blended_image is not None:
quality_results = quality_checker.run_all_checks(
foreground=image,
background=result.result_image,
mask=mask,
combined=result.blended_image
)
quality_score = quality_results.get("overall_score", 0)
else:
quality_results = {}  # keep downstream parameter adjustment safe
quality_score = 50.0  # Default if no blended image
result.quality_score = quality_score
result.quality_details = quality_results
result.retries = retry_count
logger.info(f"Quality score: {quality_score:.1f} (attempt {retry_count + 1})")
# Track best result
if quality_score > best_score:
best_score = quality_score
best_result = result
# Check if quality is acceptable
if quality_score >= self.config.min_quality_score:
logger.info(f"Quality threshold met: {quality_score:.1f}")
return best_result
# Check for minimal improvement (early termination)
if retry_count > 0 and abs(quality_score - prev_score) < 5.0:
logger.info("Minimal improvement, stopping optimization")
return best_result
prev_score = quality_score
retry_count += 1
if retry_count > self.config.max_optimization_retries:
break
# Adjust parameters based on quality issues
checks = quality_results.get("checks", {})
edge_score = checks.get("edge_continuity", {}).get("score", 100)
harmony_score = checks.get("color_harmony", {}).get("score", 100)
if edge_score < 60:
# Edge issues: increase feathering, decrease control strength
current_feather = min(20, current_feather + 3)
current_scale = max(0.5, current_scale - 0.1)
logger.debug(f"Adjusting for edges: feather={current_feather}, scale={current_scale}")
if harmony_score < 60:
# Color harmony issues: emphasize consistency in prompt
if "color consistent" not in current_prompt.lower():
current_prompt = f"{current_prompt}, color consistent with surroundings, matching lighting"
current_guidance = min(12.0, current_guidance + 1.0)
logger.debug(f"Adjusting for harmony: guidance={current_guidance}")
if edge_score < 60 and harmony_score < 60:
# Both issues: stronger guidance
current_guidance = min(12.0, current_guidance + 1.5)
logger.info(f"Optimization complete. Best score: {best_score:.1f}")
return best_result
def get_status(self) -> Dict[str, Any]:
"""
Get current module status.
Returns
-------
dict
Status information including initialization state and memory usage
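Examples
--------
Illustrative sketch of the returned structure (values depend on the
environment and loaded pipeline):
>>> status = module.get_status()
>>> status["initialized"], status["device"]
(True, 'cuda')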
"""
status = {
"initialized": self.is_initialized,
"device": self.device,
"conditioning_type": self._current_conditioning_type,
"last_seed": self._last_seed,
"config": {
"controlnet_conditioning_scale": self.config.controlnet_conditioning_scale,
"feather_radius": self.config.feather_radius,
"num_inference_steps": self.config.num_inference_steps,
"guidance_scale": self.config.guidance_scale
}
}
status["memory"] = self._check_memory_status()
return status