"""Automated quality validation for generated images.

Provides checks for mask coverage, edge continuity, and color harmony, plus
an aggregate report that combines the three.
"""

import logging
from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional

import cv2
import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class QualityResult:
    """Result of a quality check."""

    score: float  # 0-100
    passed: bool
    issue: str  # empty string when there is nothing to report
    details: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict (score, passed, issue, details)."""
        return {
            "score": self.score,
            "passed": self.passed,
            "issue": self.issue,
            "details": self.details,
        }


class QualityChecker:
    """
    Automated quality validation system for generated images.

    Provides checks for mask coverage, edge continuity, and color harmony.
    Every check returns a QualityResult built from native Python numbers and
    booleans (no NumPy scalars), so the output of run_all_checks can be
    serialized with the standard json module.
    """

    # Quality thresholds
    THRESHOLD_PASS = 70
    THRESHOLD_WARNING = 50

    # strictness -> (min_coverage, min_edge_score, min_harmony_score)
    _THRESHOLDS = {
        "lenient": (0.03, 40, 40),
        "standard": (0.05, 60, 60),
        "strict": (0.10, 75, 75),
    }

    # Minimum overall score for each grade, checked from best to worst.
    _GRADES = (
        (90, "Excellent"),
        (75, "Good"),
        (60, "Acceptable"),
        (40, "Needs Improvement"),
    )

    def __init__(self, strictness: str = "standard"):
        """
        Initialize QualityChecker.

        Args:
            strictness: Quality check strictness level
                "lenient" - Only check fatal issues
                "standard" - All checks with moderate thresholds
                "strict" - High standards required
                Any other value falls back to the "standard" thresholds.
        """
        self.strictness = strictness
        self._set_thresholds()

    def _set_thresholds(self):
        """Set quality thresholds based on strictness level."""
        preset = None
        if isinstance(self.strictness, str):
            preset = self._THRESHOLDS.get(self.strictness)
        if preset is None:
            # Unknown levels have always behaved like "standard"; keep that
            # fallback but make the (probable) typo visible in the logs.
            logger.warning(
                "Unknown strictness %r; using 'standard' thresholds",
                self.strictness,
            )
            preset = self._THRESHOLDS["standard"]
        self.min_coverage, self.min_edge_score, self.min_harmony_score = preset

    def check_mask_coverage(self, mask: Image.Image) -> QualityResult:
        """
        Verify mask coverage is adequate.

        Pixels brighter than 127 count as foreground. The score grows with
        the foreground ratio (50% coverage = 100) and loses 2 points for
        every connected region that covers 1% of the image or less.

        Args:
            mask: Grayscale mask image (L mode)

        Returns:
            QualityResult with coverage analysis. On any internal error the
            result has score 0, passed False and the error text as issue.
        """
        try:
            mask_array = np.array(mask.convert('L'))
            total_pixels = int(mask_array.size)
            if total_pixels == 0:
                # Would otherwise surface as an opaque "division by zero".
                raise ValueError("Mask image is empty")

            # Count foreground pixels
            fg_pixels = int(np.count_nonzero(mask_array > 127))
            coverage_ratio = fg_pixels / total_pixels

            # Label connected foreground regions to detect noise
            _, binary = cv2.threshold(mask_array, 127, 255, cv2.THRESH_BINARY)
            num_labels, _, stats, _ = cv2.connectedComponentsWithStats(
                binary, connectivity=8
            )
            total_regions = num_labels - 1  # label 0 is the background

            # Count significant regions (> 1% of image)
            min_region_size = total_pixels * 0.01
            region_areas = stats[1:, cv2.CC_STAT_AREA]
            significant_regions = int(
                np.count_nonzero(region_areas > min_region_size)
            )

            # Calculate fragmentation (many small regions = bad)
            fragmentation_penalty = max(
                0, (total_regions - significant_regions) * 2
            )

            # Score calculation
            coverage_score = min(100, coverage_ratio * 200)  # 50% coverage = 100 score
            final_score = max(0, coverage_score - fragmentation_penalty)

            # Determine pass/fail
            passed = (
                coverage_ratio >= self.min_coverage
                and significant_regions >= 1
            )

            issue = ""
            if coverage_ratio < self.min_coverage:
                issue = f"Low foreground coverage ({coverage_ratio:.1%})"
            elif significant_regions == 0:
                issue = "No significant foreground regions detected"
            elif fragmentation_penalty > 20:
                issue = f"Fragmented mask with {total_regions} isolated regions"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "coverage_ratio": coverage_ratio,
                    "foreground_pixels": fg_pixels,
                    "total_regions": total_regions,
                    "significant_regions": significant_regions,
                },
            )

        except Exception as e:
            # Best-effort check: report the failure instead of raising.
            logger.error("❌ Mask coverage check failed: %s", e)
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def check_edge_continuity(self, mask: Image.Image) -> QualityResult:
        """
        Check if mask edges are continuous and smooth.

        Edges are the pixels where the 3x3 morphological gradient exceeds 20.
        Smoothness is derived from the spread of the Laplacian on those
        pixels; a gap penalty of up to 40 points is then subtracted.

        Args:
            mask: Grayscale mask image

        Returns:
            QualityResult with edge analysis. A mask without any edge yields
            score 50 / passed False. On any internal error the result has
            score 0, passed False and the error text as issue.
        """
        try:
            mask_array = np.array(mask.convert('L'))

            # Find edges using morphological gradient
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            gradient = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, kernel)

            # Get edge pixels
            edge_pixels = gradient > 20
            edge_count = int(np.count_nonzero(edge_pixels))

            if edge_count == 0:
                return QualityResult(
                    score=50,
                    passed=False,
                    issue="No edges detected in mask",
                    details={"edge_count": 0},
                )

            # Check edge smoothness using Laplacian
            laplacian = cv2.Laplacian(mask_array, cv2.CV_64F)
            edge_laplacian = np.abs(laplacian[edge_pixels])

            # A wide spread of Laplacian values indicates jagged edges.
            # float() keeps NumPy scalars out of the returned result.
            smoothness = float(100 - min(100, np.std(edge_laplacian) * 0.5))

            # Check for gaps in edges
            # Dilate and erode to find disconnections
            dilated = cv2.dilate(gradient, kernel, iterations=1)
            eroded = cv2.erode(dilated, kernel, iterations=1)
            gaps = cv2.subtract(dilated, eroded)
            gap_ratio = int(np.count_nonzero(gaps)) / max(edge_count, 1)

            # Calculate final score
            gap_penalty = min(40, gap_ratio * 100)
            final_score = max(0, smoothness - gap_penalty)

            passed = bool(final_score >= self.min_edge_score)

            issue = ""
            if not passed:
                if smoothness < 60:
                    issue = "Jagged or rough edges detected"
                elif gap_ratio > 0.3:
                    issue = "Discontinuous edges with gaps"
                else:
                    issue = "Poor edge quality"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "edge_count": edge_count,
                    "smoothness": smoothness,
                    "gap_ratio": gap_ratio,
                },
            )

        except Exception as e:
            # Best-effort check: report the failure instead of raising.
            logger.error("❌ Edge continuity check failed: %s", e)
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def check_color_harmony(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image
    ) -> QualityResult:
        """
        Evaluate color harmony between foreground and background.

        Compares the mean LAB color of the foreground image inside the mask
        (mask > 127) with the mean LAB color of the background image outside
        of it.

        NOTE: the conversion runs on 8-bit data, for which OpenCV stores L
        scaled to 0-255 and a/b offset by 128. The differences and the
        thresholds below are therefore in OpenCV's 8-bit LAB units, not in
        true CIE L*a*b* units.

        Args:
            foreground: Original foreground image
            background: Generated background image
            mask: Combination mask

        Returns:
            QualityResult with harmony analysis. If the mask has no
            foreground or no background pixels the check is skipped with
            score 50 / passed True. On any internal error (including images
            of different sizes) the result has score 0, passed False and the
            error text as issue.
        """
        try:
            fg_array = np.array(foreground.convert('RGB'))
            bg_array = np.array(background.convert('RGB'))
            mask_array = np.array(mask.convert('L'))

            # The mask is used to index both images, so all three must share
            # the same height and width; say so explicitly instead of letting
            # NumPy raise an opaque boolean-index error.
            if not (fg_array.shape[:2] == bg_array.shape[:2] == mask_array.shape):
                raise ValueError(
                    "Image size mismatch (height, width): "
                    f"foreground {fg_array.shape[:2]}, "
                    f"background {bg_array.shape[:2]}, "
                    f"mask {mask_array.shape}"
                )

            # Get foreground and background regions
            fg_region = mask_array > 127
            bg_region = ~fg_region

            if not np.any(fg_region) or not np.any(bg_region):
                return QualityResult(
                    score=50,
                    passed=True,
                    issue="Cannot analyze harmony - insufficient regions",
                    details={},
                )

            # Convert to LAB for perceptual analysis
            fg_lab = cv2.cvtColor(fg_array, cv2.COLOR_RGB2LAB).astype(np.float32)
            bg_lab = cv2.cvtColor(bg_array, cv2.COLOR_RGB2LAB).astype(np.float32)

            # Calculate average colors (one mean per L, a, b channel)
            fg_avg_l, fg_avg_a, fg_avg_b = (
                np.mean(fg_lab[fg_region, channel]) for channel in range(3)
            )
            bg_avg_l, bg_avg_a, bg_avg_b = (
                np.mean(bg_lab[bg_region, channel]) for channel in range(3)
            )

            # Calculate color differences
            delta_l = abs(fg_avg_l - bg_avg_l)
            delta_a = abs(fg_avg_a - bg_avg_a)
            delta_b = abs(fg_avg_b - bg_avg_b)

            # Overall color difference (Delta E approximation)
            delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)

            # Score calculation
            # Moderate difference is good (20-60 Delta E)
            # Too similar or too different is problematic
            issue = ""
            if delta_e < 10:
                harmony_score = 60  # Too similar, foreground may get lost
                issue = "Foreground and background colors too similar"
            elif delta_e > 80:
                harmony_score = 50  # Too different, may look unnatural
                issue = "High color contrast may look unnatural"
            elif 20 <= delta_e <= 60:
                harmony_score = 100  # Ideal range
            else:
                harmony_score = 80

            # Check for extreme contrast (very dark fg on very bright bg or
            # vice versa); the brightness contrast is exactly delta_l.
            if delta_l > 100:
                harmony_score = max(40, harmony_score - 30)
                issue = "Extreme brightness contrast between foreground and background"

            passed = harmony_score >= self.min_harmony_score

            # float() keeps NumPy float32 scalars out of the returned result.
            return QualityResult(
                score=harmony_score,
                passed=passed,
                issue=issue,
                details={
                    "delta_e": float(delta_e),
                    "delta_l": float(delta_l),
                    "delta_a": float(delta_a),
                    "delta_b": float(delta_b),
                    "fg_luminance": float(fg_avg_l),
                    "bg_luminance": float(bg_avg_l),
                },
            )

        except Exception as e:
            # Best-effort check: report the failure instead of raising.
            logger.error("❌ Color harmony check failed: %s", e)
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def run_all_checks(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image,
        combined: Optional[Image.Image] = None
    ) -> Dict[str, Any]:
        """
        Run all quality checks and return comprehensive results.

        Args:
            foreground: Original foreground image
            background: Generated background
            mask: Combination mask
            combined: Final combined image (optional; accepted for interface
                compatibility, not used by any check)

        Returns:
            Dictionary with keys:
                "checks" - per-check dicts (score, passed, issue, details)
                "overall_score" - weighted average rounded to one decimal
                "passed" - True only if every individual check passed
                "warnings" - issues reported by checks that passed
                "errors" - issues reported by checks that failed
        """
        logger.info("🔍 Running quality checks...")

        # Run individual checks
        coverage_result = self.check_mask_coverage(mask)
        edge_result = self.check_edge_continuity(mask)
        harmony_result = self.check_color_harmony(foreground, background, mask)

        results = {
            "checks": {
                "mask_coverage": coverage_result.to_dict(),
                "edge_continuity": edge_result.to_dict(),
                "color_harmony": harmony_result.to_dict(),
            },
            "overall_score": 0,
            "passed": True,
            "warnings": [],
            "errors": [],
        }

        # Calculate overall score (weighted average)
        weights = {
            "mask_coverage": 0.4,
            "edge_continuity": 0.3,
            "color_harmony": 0.3,
        }
        total_score = (
            coverage_result.score * weights["mask_coverage"]
            + edge_result.score * weights["edge_continuity"]
            + harmony_result.score * weights["color_harmony"]
        )
        results["overall_score"] = round(total_score, 1)

        # Determine overall pass/fail
        results["passed"] = all(
            result.passed
            for result in (coverage_result, edge_result, harmony_result)
        )

        # Collect warnings and errors: an issue on a passing check is a
        # warning, an issue on a failing check is an error.
        for check_name, check_data in results["checks"].items():
            if not check_data["issue"]:
                continue
            bucket = "warnings" if check_data["passed"] else "errors"
            results[bucket].append(f"{check_name}: {check_data['issue']}")

        logger.info(
            "📊 Quality check complete - Score: %s, Passed: %s",
            results["overall_score"],
            results["passed"],
        )

        return results

    def get_quality_summary(self, results: Dict[str, Any]) -> str:
        """
        Generate human-readable quality summary.

        Args:
            results: Results from run_all_checks

        Returns:
            Summary string: the grade line, followed by the errors if there
            are any, otherwise by the warnings if there are any.
        """
        score = results["overall_score"]

        grade = next(
            (label for minimum, label in self._GRADES if score >= minimum),
            "Poor",
        )

        summary = f"Quality: {grade} ({score:.0f}/100)"

        if results["errors"]:
            summary += f"\nIssues: {'; '.join(results['errors'])}"
        elif results["warnings"]:
            summary += f"\nNotes: {'; '.join(results['warnings'])}"

        return summary