Spaces: Running on Zero
Upload 10 files
Browse files
- app.py +82 -0
- css_styles.py +513 -0
- image_blender.py +802 -0
- mask_generator.py +650 -0
- model_manager.py +293 -0
- quality_checker.py +409 -0
- requirements.txt +81 -0
- scene_templates.py +429 -0
- scene_weaver_core.py +808 -0
- ui_manager.py +513 -0
app.py
ADDED
@@ -0,0 +1,82 @@
import sys
import warnings
warnings.filterwarnings("ignore")

from ui_manager import UIManager

def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
    """Launch SceneWeaver Application"""

    print("🎨 Starting SceneWeaver...")
    print("✨ AI-Powered Image Background Generation")

    try:
        # Test imports first
        print("🔍 Testing imports...")
        try:
            # Test creating UIManager
            print("🔍 Creating UIManager instance...")
            ui = UIManager()
            print("✅ UIManager instance created successfully")

            # Launch UI
            print("🚀 Launching interface...")
            interface = ui.launch(share=share, debug=debug)
            print("✅ Interface launched successfully")
            return interface

        except ImportError as import_error:
            import traceback
            print(f"❌ Import failed: {import_error}")
            print(f"Traceback: {traceback.format_exc()}")
            raise

    except Exception as e:
        import traceback
        print(f"❌ Failed to launch: {e}")
        print(f"Full traceback: {traceback.format_exc()}")
        raise

def launch_ui(share: bool = True, debug: bool = False):
    """Convenience function for Jupyter notebooks"""
    return launch_final_blend_sceneweaver(share=share, debug=debug)

def main():
    """Main entry point"""

    # Check if running in Jupyter/Colab
    try:
        get_ipython()
        is_jupyter = True
    except NameError:
        is_jupyter = False

    if not is_jupyter and len(sys.argv) > 1 and not any('-f' in arg for arg in sys.argv):
        # Command line mode with arguments
        share = '--no-share' not in sys.argv
        debug = '--debug' in sys.argv
    else:
        # Default mode
        share = True
        debug = False

    try:
        interface = launch_final_blend_sceneweaver(share=share, debug=debug)

        if not is_jupyter:
            print("🛑 Press Ctrl+C to stop")
            try:
                interface.block_thread()
            except KeyboardInterrupt:
                print("👋 Stopped")

        return interface

    except Exception as e:
        print(f"❌ Error: {e}")
        if not is_jupyter:
            sys.exit(1)
        raise

if __name__ == "__main__":
    main()
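A minimal sketch of how this entry point can be driven once the Space's files are in place; launch_ui, launch_final_blend_sceneweaver, and the --no-share / --debug flags are taken from app.py above, while the notebook and shell contexts are assumptions:

# In a notebook, reuse the convenience wrapper defined in app.py
from app import launch_ui
demo = launch_ui(share=True, debug=False)  # wraps launch_final_blend_sceneweaver()

# From a shell, main() parses the same flags:
#   python app.py --no-share --debug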
css_styles.py
ADDED
@@ -0,0 +1,513 @@
class CSSStyles:
    """
    CSS styling configuration for the SceneWeaver application.
    Professional design system with clean typography and modern aesthetics.
    """

    @staticmethod
    def get_main_css() -> str:
        """
        Get the main CSS styling for the application.

        Returns:
            Complete CSS string for Gradio interface styling
        """
        return """
        /* Import professional fonts */
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');

        /* CSS Variables - Professional color system */
        :root {
            /* Primary brand colors */
            --primary-color: #1e3a5f;
            --primary-hover: #2d5a8a;
            --primary-light: #e8f4fd;

            /* Accent colors */
            --accent-color: #3b82f6;
            --accent-hover: #2563eb;
            --accent-light: #dbeafe;

            /* Status colors */
            --success-color: #10b981;
            --warning-color: #f59e0b;
            --error-color: #ef4444;

            /* Neutral colors */
            --bg-primary: #ffffff;
            --bg-secondary: #f8fafc;
            --bg-tertiary: #f1f5f9;
            --text-primary: #1e293b;
            --text-secondary: #475569;
            --text-muted: #94a3b8;
            --border-color: #e2e8f0;
            --border-light: #f1f5f9;

            /* Shadows */
            --shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
            --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -2px rgba(0, 0, 0, 0.1);
            --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -4px rgba(0, 0, 0, 0.1);
            --shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 8px 10px -6px rgba(0, 0, 0, 0.1);

            /* Border radius */
            --radius-sm: 6px;
            --radius-md: 8px;
            --radius-lg: 12px;
            --radius-xl: 16px;

            /* Transitions */
            --transition-fast: 150ms ease;
            --transition-normal: 250ms ease;
            --transition-slow: 350ms ease;
        }

        /* Global styles */
        * {
            font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
            -webkit-font-smoothing: antialiased !important;
            -moz-osx-font-smoothing: grayscale !important;
        }

        /* Main container */
        .gradio-container {
            background: linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%) !important;
            min-height: 100vh !important;
            padding: 24px !important;
        }

        /* ===== HEADER SECTION ===== */
        .main-header {
            text-align: center !important;
            padding: 48px 32px 40px !important;
            margin-bottom: 32px !important;
            background: linear-gradient(135deg, var(--bg-primary) 0%, var(--bg-secondary) 100%) !important;
            border-radius: var(--radius-xl) !important;
            box-shadow: var(--shadow-md) !important;
            border: 1px solid var(--border-light) !important;
        }

        .main-title {
            font-size: 2.75rem !important;
            font-weight: 700 !important;
            color: var(--primary-color) !important;
            margin: 0 0 12px 0 !important;
            letter-spacing: -0.03em !important;
            display: flex !important;
            align-items: center !important;
            justify-content: center !important;
            gap: 14px !important;
            line-height: 1.2 !important;
        }

        .title-emoji {
            font-size: 2.5rem !important;
            filter: drop-shadow(0 2px 4px rgba(0,0,0,0.15)) !important;
            transition: transform var(--transition-normal) !important;
        }

        .title-emoji:hover {
            transform: scale(1.1) rotate(-5deg) !important;
        }

        .main-subtitle {
            font-size: 1.1rem !important;
            color: var(--text-secondary) !important;
            font-weight: 400 !important;
            margin: 0 !important;
            line-height: 1.5 !important;
            max-width: 700px !important;
            margin: 0 auto !important;
        }

        /* ===== CARD SYSTEM ===== */
        .feature-card {
            background: var(--bg-primary) !important;
            border: 1px solid var(--border-color) !important;
            border-radius: var(--radius-lg) !important;
            padding: 24px !important;
            margin-bottom: 20px !important;
            box-shadow: var(--shadow-sm) !important;
            transition: all var(--transition-normal) !important;
            position: relative !important;
        }

        .feature-card:hover {
            border-color: var(--accent-color) !important;
            box-shadow: var(--shadow-lg) !important;
            transform: translateY(-2px) !important;
        }

        .card-title {
            font-size: 1.25rem !important;
            font-weight: 600 !important;
            color: var(--text-primary) !important;
            margin-bottom: 16px !important;
            display: flex !important;
            align-items: center !important;
            gap: 10px !important;
        }

        .section-emoji {
            font-size: 1.2rem !important;
            filter: drop-shadow(0 1px 2px rgba(0,0,0,0.1)) !important;
        }

        /* ===== INPUT COMPONENTS ===== */
        .input-field {
            border: 1px solid var(--border-color) !important;
            border-radius: var(--radius-md) !important;
            background: var(--bg-primary) !important;
            transition: all var(--transition-fast) !important;
        }

        .input-field:focus-within {
            border-color: var(--accent-color) !important;
            box-shadow: 0 0 0 3px var(--accent-light) !important;
        }

        /* ===== BUTTONS ===== */
        .primary-button {
            background: linear-gradient(135deg, var(--accent-color) 0%, var(--accent-hover) 100%) !important;
            color: white !important;
            border: none !important;
            border-radius: var(--radius-md) !important;
            padding: 14px 28px !important;
            font-size: 1rem !important;
            font-weight: 600 !important;
            cursor: pointer !important;
            transition: all var(--transition-normal) !important;
            box-shadow: var(--shadow-md) !important;
        }

        .primary-button:hover {
            transform: translateY(-2px) !important;
            box-shadow: var(--shadow-lg) !important;
            filter: brightness(1.05) !important;
        }

        .secondary-button {
            background: var(--bg-primary) !important;
            color: var(--accent-color) !important;
            border: 1.5px solid var(--accent-color) !important;
            border-radius: var(--radius-md) !important;
            padding: 12px 20px !important;
            font-size: 0.95rem !important;
            font-weight: 500 !important;
            cursor: pointer !important;
            transition: all var(--transition-fast) !important;
        }

        .secondary-button:hover {
            background: var(--accent-light) !important;
            transform: translateY(-1px) !important;
        }

        /* ===== RESULTS GALLERY ===== */
        #results-gallery-centered {
            display: flex !important;
            flex-direction: column !important;
            align-items: center !important;
        }

        #results-gallery-centered .gradio-tabs {
            width: 100% !important;
        }

        .result-gallery {
            border-radius: var(--radius-lg) !important;
            overflow: hidden !important;
            border: 1px solid var(--border-color) !important;
            box-shadow: var(--shadow-md) !important;
        }

        .result-gallery img {
            width: 100% !important;
            height: auto !important;
            object-fit: contain !important;
        }

        /* ===== STATUS PANEL ===== */
        .status-panel {
            background: var(--bg-secondary) !important;
            border: 1px solid var(--border-color) !important;
            border-radius: var(--radius-md) !important;
            padding: 12px 16px !important;
            margin: 16px 0 !important;
        }

        .status-ready {
            color: var(--success-color) !important;
            font-weight: 500 !important;
        }

        /* ===== LOADING NOTICE ===== */
        .loading-notice {
            background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%) !important;
            border: 1px solid #f59e0b !important;
            border-radius: var(--radius-md) !important;
            padding: 14px 18px !important;
            margin: 16px 0 !important;
            display: flex !important;
            align-items: center !important;
            gap: 12px !important;
        }

        .loading-notice-icon {
            font-size: 1.3rem !important;
            flex-shrink: 0 !important;
        }

        .loading-notice-text {
            color: #92400e !important;
            font-size: 0.9rem !important;
            font-weight: 500 !important;
            line-height: 1.5 !important;
        }

        /* ===== QUICK START GUIDE ===== */
        .user-guidance-panel {
            background: var(--bg-secondary) !important;
            border: 1px solid var(--border-color) !important;
            border-radius: var(--radius-md) !important;
            margin: 16px 0 !important;
            overflow: hidden !important;
        }

        .guidance-summary {
            background: var(--bg-primary) !important;
            padding: 12px 16px !important;
            cursor: pointer !important;
            font-weight: 500 !important;
            color: var(--text-primary) !important;
            transition: background var(--transition-fast) !important;
            list-style: none !important;
            display: flex !important;
            align-items: center !important;
            gap: 8px !important;
            border-bottom: 1px solid var(--border-color) !important;
        }

        .guidance-summary:hover {
            background: var(--accent-light) !important;
        }

        .guidance-summary::-webkit-details-marker {
            display: none !important;
        }

        .guidance-content {
            padding: 16px !important;
            color: var(--text-secondary) !important;
            line-height: 1.6 !important;
        }

        .guidance-content p {
            margin: 8px 0 !important;
            font-size: 0.9rem !important;
        }

        .guidance-content strong {
            color: var(--primary-color) !important;
            font-weight: 600 !important;
        }

        /* ===== FOOTER ===== */
        .app-footer {
            background: var(--bg-primary) !important;
            border: 1px solid var(--border-color) !important;
            border-radius: var(--radius-lg) !important;
            padding: 32px !important;
            margin-top: 32px !important;
            text-align: center !important;
        }

        .footer-powered {
            margin-bottom: 20px !important;
        }

        .footer-powered-title {
            font-size: 0.85rem !important;
            font-weight: 500 !important;
            color: var(--text-muted) !important;
            text-transform: uppercase !important;
            letter-spacing: 0.1em !important;
            margin-bottom: 16px !important;
        }

        .footer-tech-grid {
            display: flex !important;
            flex-wrap: wrap !important;
            justify-content: center !important;
            gap: 12px !important;
            margin-bottom: 24px !important;
        }

        .footer-tech-item {
            background: var(--bg-secondary) !important;
            border: 1px solid var(--border-color) !important;
            border-radius: var(--radius-sm) !important;
            padding: 8px 16px !important;
            font-size: 0.85rem !important;
            font-weight: 500 !important;
            color: var(--text-secondary) !important;
            transition: all var(--transition-fast) !important;
        }

        .footer-tech-item:hover {
            background: var(--accent-light) !important;
            border-color: var(--accent-color) !important;
            color: var(--accent-color) !important;
        }

        .footer-divider {
            height: 1px !important;
            background: var(--border-color) !important;
            margin: 20px auto !important;
            max-width: 400px !important;
        }

        .footer-copyright {
            font-size: 0.85rem !important;
            color: var(--text-muted) !important;
            font-weight: 400 !important;
        }

        .footer-copyright a {
            color: var(--accent-color) !important;
            text-decoration: none !important;
            font-weight: 500 !important;
        }

        .footer-copyright a:hover {
            text-decoration: underline !important;
        }

        /* ===== TABS STYLING ===== */
        .gradio-tabs {
            border: none !important;
        }

        .gradio-tabs > .tab-nav {
            background: var(--bg-secondary) !important;
            border-radius: var(--radius-md) !important;
            padding: 4px !important;
            gap: 4px !important;
            border: 1px solid var(--border-color) !important;
            margin-bottom: 16px !important;
        }

        .gradio-tabs > .tab-nav > button {
            border-radius: var(--radius-sm) !important;
            padding: 10px 20px !important;
            font-weight: 500 !important;
            font-size: 0.9rem !important;
            transition: all var(--transition-fast) !important;
            border: none !important;
            background: transparent !important;
            color: var(--text-secondary) !important;
        }

        .gradio-tabs > .tab-nav > button.selected {
            background: var(--bg-primary) !important;
            color: var(--accent-color) !important;
            box-shadow: var(--shadow-sm) !important;
        }

        .gradio-tabs > .tab-nav > button:hover:not(.selected) {
            background: var(--bg-primary) !important;
            color: var(--text-primary) !important;
        }

        /* ===== ACCORDION ===== */
        .gradio-accordion {
            border: 1px solid var(--border-color) !important;
            border-radius: var(--radius-md) !important;
            overflow: hidden !important;
            margin-top: 12px !important;
        }

        .gradio-accordion > .label-wrap {
            background: var(--bg-secondary) !important;
            padding: 12px 16px !important;
            font-weight: 500 !important;
        }

        .gradio-accordion > .label-wrap:hover {
            background: var(--bg-tertiary) !important;
        }

        /* ===== RESPONSIVE ===== */
        @media (max-width: 768px) {
            .main-title {
                font-size: 2rem !important;
            }

            .main-subtitle {
                font-size: 1rem !important;
            }

            .footer-tech-grid {
                gap: 8px !important;
            }

            .footer-tech-item {
                padding: 6px 12px !important;
                font-size: 0.8rem !important;
            }
        }

        /* ===== EMOJI ENHANCEMENT ===== */
        .emoji-enhanced {
            display: inline-block !important;
            font-style: normal !important;
            filter: drop-shadow(0 1px 2px rgba(0,0,0,0.1)) !important;
            transition: transform var(--transition-fast) !important;
        }

        .emoji-enhanced:hover {
            transform: scale(1.1) !important;
        }

        /* ===== IMAGE DISPLAY FIX ===== */
        .gradio-image {
            min-height: 200px !important;
        }

        .gradio-image img {
            max-height: 500px !important;
            object-fit: contain !important;
        }

        /* ===== SCENE TEMPLATE DROPDOWN ===== */
        .template-dropdown {
            margin: 8px 0 !important;
        }

        .template-dropdown select,
        .template-dropdown input {
            font-size: 0.95rem !important;
            padding: 10px 14px !important;
            border-radius: var(--radius-md) !important;
            border: 1px solid var(--border-color) !important;
            background: var(--bg-primary) !important;
            transition: all var(--transition-fast) !important;
        }

        .template-dropdown select:hover,
        .template-dropdown input:hover {
            border-color: var(--accent-color) !important;
        }

        .template-dropdown select:focus,
        .template-dropdown input:focus {
            border-color: var(--accent-color) !important;
            box-shadow: 0 0 0 3px var(--accent-light) !important;
            outline: none !important;
        }

        /* Dropdown option styling */
        .template-dropdown option {
            padding: 8px 12px !important;
            font-size: 0.95rem !important;
        }
        """
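css_styles.py only defines the stylesheet string; the Gradio wiring lives in ui_manager.py, which is not part of this excerpt. A hedged sketch of the usual pattern, assuming the UI passes the string to gr.Blocks via its css argument:

import gradio as gr
from css_styles import CSSStyles

# Assumed wiring (not the actual ui_manager.py code): apply the stylesheet app-wide
with gr.Blocks(css=CSSStyles.get_main_css()) as demo:
    gr.HTML('<div class="main-header"><h1 class="main-title">SceneWeaver</h1></div>')

demo.launch()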
image_blender.py
ADDED
@@ -0,0 +1,802 @@
import cv2
import numpy as np
from PIL import Image
import logging
from typing import Dict, Any, Optional, Tuple

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class ImageBlender:
    """
    Advanced image blending with aggressive spill suppression and color replacement.
    Completely eliminates yellow edge residue while maintaining sharp edges.
    """

    EDGE_EROSION_PIXELS = 1  # Pixels to erode from mask edge (reduced to protect more foreground)
    ALPHA_BINARIZE_THRESHOLD = 0.5  # Alpha threshold for binarization (increased to keep more foreground)
    DARK_LUMINANCE_THRESHOLD = 60  # Luminance threshold for dark foreground detection
    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value above which pixels are strongly protected
    BACKGROUND_COLOR_TOLERANCE = 30  # DeltaE tolerance for background color detection

    def __init__(self, enable_multi_scale: bool = True):
        """
        Initialize ImageBlender.

        Args:
            enable_multi_scale: Whether to enable multi-scale edge refinement (default True)
        """
        self.enable_multi_scale = enable_multi_scale
        self._debug_info = {}
        self._adaptive_strength_map = None

    def _erode_mask_edges(
        self,
        mask_array: np.ndarray,
        erosion_pixels: int = 2
    ) -> np.ndarray:
        """
        Erode mask edges to remove contaminated boundary pixels.

        This removes the outermost pixels of the foreground mask where
        color contamination from the original background is most likely.

        Args:
            mask_array: Input mask as numpy array (uint8, 0-255)
            erosion_pixels: Number of pixels to erode (default 2)

        Returns:
            Eroded mask array (uint8)
        """
        if erosion_pixels <= 0:
            return mask_array

        # Use elliptical kernel for natural-looking erosion
        kernel_size = max(2, erosion_pixels)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (kernel_size, kernel_size)
        )

        # Apply erosion
        eroded = cv2.erode(mask_array, kernel, iterations=1)

        # Slight blur to smooth the eroded edges
        eroded = cv2.GaussianBlur(eroded, (3, 3), 0)

        logger.debug(f"Mask erosion applied: {erosion_pixels}px, kernel size: {kernel_size}")
        return eroded

    def _binarize_edge_alpha(
        self,
        alpha: np.ndarray,
        mask_array: np.ndarray,
        orig_array: np.ndarray,
        threshold: float = 0.45
    ) -> np.ndarray:
        """
        Binarize semi-transparent edge pixels to eliminate color bleeding.

        Semi-transparent pixels at edges cause visible contamination because
        they blend the original (potentially dark) foreground with the new
        background. This method forces edge pixels to be either fully opaque
        or fully transparent.

        Args:
            alpha: Current alpha channel (float32, 0.0-1.0)
            mask_array: Original mask array (uint8, 0-255)
            orig_array: Original foreground image array (uint8, RGB)
            threshold: Alpha threshold for binarization decision (default 0.45)

        Returns:
            Modified alpha array with binarized edges (float32)
        """
        # Identify semi-transparent edge zone (not fully opaque, not fully transparent)
        edge_zone = (alpha > 0.05) & (alpha < 0.95)

        if not np.any(edge_zone):
            return alpha

        # Calculate local foreground luminance for adaptive thresholding
        gray = np.mean(orig_array, axis=2)

        # For dark foreground pixels, use slightly higher threshold
        # to preserve more of the dark subject
        is_dark = gray < self.DARK_LUMINANCE_THRESHOLD

        # Create adaptive threshold map
        adaptive_threshold = np.full_like(alpha, threshold)
        adaptive_threshold[is_dark] = threshold + 0.1  # Keep more dark pixels

        # Binarize: above threshold -> opaque, below -> transparent
        alpha_binarized = alpha.copy()

        # Pixels above threshold become fully opaque
        make_opaque = edge_zone & (alpha > adaptive_threshold)
        alpha_binarized[make_opaque] = 1.0

        # Pixels below threshold become fully transparent
        make_transparent = edge_zone & (alpha <= adaptive_threshold)
        alpha_binarized[make_transparent] = 0.0

        # Log statistics
        num_opaque = np.sum(make_opaque)
        num_transparent = np.sum(make_transparent)
        logger.info(f"Edge binarization: {num_opaque} pixels -> opaque, {num_transparent} pixels -> transparent")

        return alpha_binarized

    def _apply_edge_cleanup(
        self,
        result_array: np.ndarray,
        bg_array: np.ndarray,
        alpha: np.ndarray,
        cleanup_width: int = 2
    ) -> np.ndarray:
        """
        Final cleanup pass to remove any remaining edge artifacts.

        Detects remaining semi-transparent edges and replaces them with
        either pure foreground or pure background colors.

        Args:
            result_array: Current blended result (uint8, RGB)
            bg_array: Background image array (uint8, RGB)
            alpha: Final alpha channel (float32, 0.0-1.0)
            cleanup_width: Width of edge zone to clean (default 2)

        Returns:
            Cleaned result array (uint8)
        """
        # Find edge pixels that might still have artifacts
        # These are pixels with alpha close to but not exactly 0 or 1
        residual_edge = (alpha > 0.01) & (alpha < 0.99) & (alpha != 0.0) & (alpha != 1.0)

        if not np.any(residual_edge):
            return result_array

        result_cleaned = result_array.copy()

        # For residual edge pixels, snap to nearest pure state
        snap_to_bg = residual_edge & (alpha < 0.5)
        snap_to_fg = residual_edge & (alpha >= 0.5)

        # Replace with background
        result_cleaned[snap_to_bg] = bg_array[snap_to_bg]

        # For foreground, keep original but ensure no blending artifacts
        # (already handled by the blend, so no action needed for snap_to_fg)

        num_cleaned = np.sum(residual_edge)
        if num_cleaned > 0:
            logger.debug(f"Edge cleanup: {num_cleaned} residual pixels cleaned")

        return result_cleaned

    def _remove_background_color_contamination(
        self,
        image_array: np.ndarray,
        mask_array: np.ndarray,
        orig_bg_color_lab: np.ndarray,
        tolerance: float = 30.0
    ) -> np.ndarray:
        """
        Remove original background color contamination from foreground pixels.

        Scans the foreground area for pixels that match the original background
        color and replaces them with nearby clean foreground colors.

        Args:
            image_array: Foreground image array (uint8, RGB)
            mask_array: Mask array (uint8, 0-255)
            orig_bg_color_lab: Original background color in Lab space
            tolerance: DeltaE tolerance for detecting contaminated pixels

        Returns:
            Cleaned image array (uint8)
        """
        # Convert to Lab for color comparison
        image_lab = cv2.cvtColor(image_array, cv2.COLOR_RGB2LAB).astype(np.float32)

        # Only process foreground pixels (mask > 50)
        foreground_mask = mask_array > 50

        if not np.any(foreground_mask):
            return image_array

        # Calculate deltaE from original background color for all pixels
        delta_l = image_lab[:, :, 0] - orig_bg_color_lab[0]
        delta_a = image_lab[:, :, 1] - orig_bg_color_lab[1]
        delta_b = image_lab[:, :, 2] - orig_bg_color_lab[2]
        delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)

        # Find contaminated pixels: in foreground but color similar to original background
        contaminated = foreground_mask & (delta_e < tolerance)

        if not np.any(contaminated):
            logger.debug("No background color contamination detected in foreground")
            return image_array

        num_contaminated = np.sum(contaminated)
        logger.info(f"Found {num_contaminated} pixels with background color contamination")

        # Create output array
        result = image_array.copy()

        # For contaminated pixels, use inpainting to replace with surrounding colors
        inpaint_mask = contaminated.astype(np.uint8) * 255

        try:
            # Use inpainting to fill contaminated areas with surrounding foreground colors
            result = cv2.inpaint(result, inpaint_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
            logger.info(f"Inpainted {num_contaminated} contaminated pixels")
        except Exception as e:
            logger.warning(f"Inpainting failed: {e}, using median filter fallback")
            # Fallback: apply median filter to contaminated areas
            median_filtered = cv2.medianBlur(image_array, 5)
            result[contaminated] = median_filtered[contaminated]

        return result

    def _protect_foreground_core(
        self,
        result_array: np.ndarray,
        orig_array: np.ndarray,
        mask_array: np.ndarray,
        protection_threshold: int = 140
    ) -> np.ndarray:
        """
        Strongly protect core foreground pixels from any background influence.

        For pixels with high mask confidence, directly use the original foreground
        color without any blending, ensuring faces and bodies are not affected.

        Args:
            result_array: Current blended result (uint8, RGB)
            orig_array: Original foreground image (uint8, RGB)
            mask_array: Mask array (uint8, 0-255)
            protection_threshold: Mask value above which pixels are fully protected

        Returns:
            Protected result array (uint8)
        """
        # Identify strongly protected foreground pixels
        strong_foreground = mask_array >= protection_threshold

        if not np.any(strong_foreground):
            return result_array

        # For these pixels, use original foreground color directly
        result_protected = result_array.copy()
        result_protected[strong_foreground] = orig_array[strong_foreground]

        num_protected = np.sum(strong_foreground)
        logger.info(f"Protected {num_protected} core foreground pixels from background influence")

        return result_protected

    def multi_scale_edge_refinement(
        self,
        original_image: Image.Image,
        background_image: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Multi-scale edge refinement for better edge quality.
        Uses an image pyramid to handle edges at different scales.

        Args:
            original_image: Foreground PIL Image
            background_image: Background PIL Image
            mask: Current mask PIL Image

        Returns:
            Refined mask PIL Image
        """
        logger.info("🔍 Starting multi-scale edge refinement...")

        try:
            # Convert to numpy arrays
            orig_array = np.array(original_image.convert('RGB'))
            mask_array = np.array(mask).astype(np.float32)
            height, width = mask_array.shape

            # Define scales for pyramid
            scales = [1.0, 0.5, 0.25]  # Original, half, quarter
            scale_masks = []
            scale_complexities = []

            # Convert to grayscale for edge detection
            gray = cv2.cvtColor(orig_array, cv2.COLOR_RGB2GRAY)

            for scale in scales:
                if scale == 1.0:
                    scaled_gray = gray
                    scaled_mask = mask_array
                else:
                    new_h = int(height * scale)
                    new_w = int(width * scale)
                    scaled_gray = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                    scaled_mask = cv2.resize(mask_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

                # Compute local complexity using gradient standard deviation
                sobel_x = cv2.Sobel(scaled_gray, cv2.CV_64F, 1, 0, ksize=3)
                sobel_y = cv2.Sobel(scaled_gray, cv2.CV_64F, 0, 1, ksize=3)
                gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)

                # Calculate local complexity in 5x5 regions
                kernel_size = 5
                complexity = cv2.blur(gradient_mag, (kernel_size, kernel_size))

                # Resize back to original size
                if scale != 1.0:
                    scaled_mask = cv2.resize(scaled_mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
                    complexity = cv2.resize(complexity, (width, height), interpolation=cv2.INTER_LANCZOS4)

                scale_masks.append(scaled_mask)
                scale_complexities.append(complexity)

            # Compute weights based on complexity
            # High complexity -> use high resolution mask
            # Low complexity -> use low resolution mask (smoother)
            weights = np.zeros((len(scales), height, width), dtype=np.float32)

            # Normalize complexities
            max_complexity = max(c.max() for c in scale_complexities) + 1e-6
            normalized_complexities = [c / max_complexity for c in scale_complexities]

            # Weight assignment: higher complexity at each scale means that scale is more reliable
            for i, complexity in enumerate(normalized_complexities):
                if i == 0:  # High resolution - prefer for high complexity regions
                    weights[i] = complexity
                elif i == 1:  # Medium resolution - moderate complexity
                    weights[i] = 0.5 * (1 - complexity) + 0.5 * complexity * 0.5
                else:  # Low resolution - prefer for low complexity regions
                    weights[i] = 1 - complexity

            # Normalize weights so they sum to 1 at each pixel
            weight_sum = weights.sum(axis=0, keepdims=True) + 1e-6
            weights = weights / weight_sum

            # Weighted blend of masks from different scales
            refined_mask = np.zeros((height, width), dtype=np.float32)
            for i, mask_i in enumerate(scale_masks):
                refined_mask += weights[i] * mask_i

            # Clip and convert to uint8
            refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)

            logger.info("✅ Multi-scale edge refinement completed")
            return Image.fromarray(refined_mask, mode='L')

        except Exception as e:
            logger.error(f"❌ Multi-scale refinement failed: {e}, using original mask")
            return mask

    def simple_blend_images(
        self,
        original_image: Image.Image,
        background_image: Image.Image,
        combination_mask: Image.Image,
        use_multi_scale: Optional[bool] = None
    ) -> Image.Image:
        """
        Aggressive spill suppression + color replacement: completely eliminate yellow edge residue while maintaining sharp edges.

        Args:
            original_image: Foreground PIL Image
            background_image: Background PIL Image
            combination_mask: Mask PIL Image (L mode)
            use_multi_scale: Override for multi-scale refinement (None = use class default)

        Returns:
            Blended PIL Image
        """
        logger.info("🎨 Starting advanced image blending process...")

        # Apply multi-scale edge refinement if enabled
        should_use_multi_scale = use_multi_scale if use_multi_scale is not None else self.enable_multi_scale
        if should_use_multi_scale:
            combination_mask = self.multi_scale_edge_refinement(
                original_image, background_image, combination_mask
            )

        # Convert to numpy arrays
        orig_array = np.array(original_image, dtype=np.uint8)
        bg_array = np.array(background_image, dtype=np.uint8)
        mask_array = np.array(combination_mask, dtype=np.uint8)

        logger.info(f"📊 Image dimensions - Original: {orig_array.shape}, Background: {bg_array.shape}, Mask: {mask_array.shape}")
        logger.info(f"📊 Mask statistics (before erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")

        # === Apply mask erosion to remove contaminated edge pixels ===
        mask_array = self._erode_mask_edges(mask_array, self.EDGE_EROSION_PIXELS)
        logger.info(f"📊 Mask statistics (after erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")

        # Enhanced parameters for better spill suppression
        RING_WIDTH_PX = 4  # Increased ring width for better coverage
        SPILL_STRENGTH = 0.85  # Stronger spill suppression
        L_MATCH_STRENGTH = 0.65  # Stronger luminance matching
        DELTAE_THRESHOLD = 18  # More aggressive contamination detection
        HARD_EDGE_PROTECT = True  # Black edge protection
        INPAINT_FALLBACK = True  # Inpaint fallback repair
        MULTI_PASS_CORRECTION = True  # Enable multi-pass correction

        # === Estimate original background color and foreground representative color ===
        height, width = orig_array.shape[:2]

        # Take 15px from each side to estimate the original background color
        edge_width = 15
        border_pixels = []

        # Collect border pixels (excluding foreground areas)
        border_mask = np.zeros((height, width), dtype=bool)
        border_mask[:edge_width, :] = True  # Top edge
        border_mask[-edge_width:, :] = True  # Bottom edge
        border_mask[:, :edge_width] = True  # Left edge
        border_mask[:, -edge_width:] = True  # Right edge

        # Exclude foreground areas
        fg_binary = mask_array > 50
        border_mask = border_mask & (~fg_binary)

        if np.any(border_mask):
            border_pixels = orig_array[border_mask].reshape(-1, 3)

            # Simplified background color estimation (no sklearn dependency)
            try:
                if len(border_pixels) > 100:
                    # Use histogram to find mode colors
                    # Quantize RGB to a coarser grid to find main colors
                    quantized = (border_pixels // 32) * 32  # 8-level quantization

                    # Find most frequent color
                    unique_colors, counts = np.unique(quantized.reshape(-1, quantized.shape[-1]),
                                                      axis=0, return_counts=True)
                    most_common_idx = np.argmax(counts)
                    orig_bg_color_rgb = unique_colors[most_common_idx].astype(np.uint8)
                else:
                    orig_bg_color_rgb = np.median(border_pixels, axis=0).astype(np.uint8)
            except:
                # Fallback: use the average of the four corners
                corners = np.array([orig_array[0, 0], orig_array[0, -1],
                                    orig_array[-1, 0], orig_array[-1, -1]])
                orig_bg_color_rgb = np.mean(corners, axis=0).astype(np.uint8)
        else:
            orig_bg_color_rgb = np.array([200, 180, 120], dtype=np.uint8)  # Default yellow

        # Convert to Lab space
        orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1, 1, 3), cv2.COLOR_RGB2LAB)[0, 0].astype(np.float32)
        logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")

        # Remove original background color contamination from foreground
        orig_array = self._remove_background_color_contamination(
            orig_array,
            mask_array,
            orig_bg_color_lab,
            tolerance=self.BACKGROUND_COLOR_TOLERANCE
        )

        # Redefine trimap, optimized for cartoon characters
        try:
            kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

            # FG_CORE: Reduce erosion iterations from 2 to 1 to avoid losing thin limbs
            mask_eroded_once = cv2.erode(mask_array, kernel_3x3, iterations=1)
            fg_core = mask_eroded_once > 127  # Adjustable parameter: erosion iterations

            # RING: Use the morphological gradient so only a thin edge band remains
            mask_dilated = cv2.dilate(mask_array, kernel_3x3, iterations=1)
            mask_eroded = cv2.erode(mask_array, kernel_3x3, iterations=1)

            # Ensure consistent data types to avoid overflow
            morphological_gradient = cv2.subtract(mask_dilated, mask_eroded)
            ring_zone = morphological_gradient > 0  # Areas with morphological gradient > 0 are edge bands

            # BG: background area
            bg_zone = mask_array < 30

            logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")

        except Exception as e:
            import traceback
            logger.error(f"❌ Trimap definition failed: {e}")
            logger.error(f"📍 Traceback: {traceback.format_exc()}")
            print(f"❌ TRIMAP ERROR: {e}")
            print(f"Traceback: {traceback.format_exc()}")
            # Fallback to simple definition
            fg_core = mask_array > 200
            ring_zone = (mask_array > 50) & (mask_array <= 200)
            bg_zone = mask_array <= 50

        # Foreground representative color: estimated from FG_CORE
        if np.any(fg_core):
            fg_pixels = orig_array[fg_core].reshape(-1, 3)
            fg_rep_color_rgb = np.median(fg_pixels, axis=0).astype(np.uint8)
        else:
            fg_rep_color_rgb = np.array([80, 60, 40], dtype=np.uint8)  # Default dark

        fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1, 1, 3), cv2.COLOR_RGB2LAB)[0, 0].astype(np.float32)

        # Edge band spill suppression and repair
        if np.any(ring_zone):
            # Convert to Lab space
            orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
            orig_array_working = orig_array.copy().astype(np.float32)

            # Detect contaminated pixels via ΔE
            ring_pixels_lab = orig_lab[ring_zone]

            # Calculate ΔE against the original background color (simplified version)
            delta_l = ring_pixels_lab[:, 0] - orig_bg_color_lab[0]
            delta_a = ring_pixels_lab[:, 1] - orig_bg_color_lab[1]
            delta_b = ring_pixels_lab[:, 2] - orig_bg_color_lab[2]
            delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)

            # Contaminated pixel mask
            contaminated_mask = delta_e < DELTAE_THRESHOLD

            if np.any(contaminated_mask):
                # Calculate adaptive strength based on delta_e for each pixel
                # Pixels closer to background color get stronger correction
                contaminated_delta_e = delta_e[contaminated_mask]

                # Adaptive strength formula: inverse relationship with delta_e
                # Pixels very close to bg color (low delta_e) -> strong correction
                # Pixels further from bg color (high delta_e) -> lighter correction
                adaptive_strength = SPILL_STRENGTH * np.maximum(
                    0.0,
                    1.0 - (contaminated_delta_e / DELTAE_THRESHOLD)
                )

                # Clamp adaptive strength to reasonable range (30% - 100% of base strength)
                min_strength = SPILL_STRENGTH * 0.3
                adaptive_strength = np.clip(adaptive_strength, min_strength, SPILL_STRENGTH)

                # Store for debug visualization
                self._adaptive_strength_map = np.zeros_like(delta_e)
                self._adaptive_strength_map[contaminated_mask] = adaptive_strength

                logger.info(f"📊 Adaptive strength stats - Mean: {adaptive_strength.mean():.3f}, Min: {adaptive_strength.min():.3f}, Max: {adaptive_strength.max():.3f}")

                # Chroma vector deprojection
                bg_chroma = np.array([orig_bg_color_lab[1], orig_bg_color_lab[2]])
                bg_chroma_norm = bg_chroma / (np.linalg.norm(bg_chroma) + 1e-6)

                # Color correction for contaminated pixels
                contaminated_pixels = ring_pixels_lab[contaminated_mask]

                # Remove background chroma component with adaptive strength (per-pixel)
                pixel_chroma = contaminated_pixels[:, 1:3]  # a, b channels
                projection = np.dot(pixel_chroma, bg_chroma_norm)[:, np.newaxis] * bg_chroma_norm

                # Apply adaptive strength per pixel
                adaptive_strength_2d = adaptive_strength[:, np.newaxis]
                corrected_chroma = pixel_chroma - projection * adaptive_strength_2d

                # Converge toward foreground representative color with adaptive strength
                convergence_factor = adaptive_strength_2d * 0.6
                corrected_chroma = (corrected_chroma * (1 - convergence_factor) +
                                    fg_rep_color_lab[1:3] * convergence_factor)

                # Adaptive luminance matching
                adaptive_l_strength = adaptive_strength * (L_MATCH_STRENGTH / SPILL_STRENGTH)
                corrected_l = (contaminated_pixels[:, 0] * (1 - adaptive_l_strength) +
                               fg_rep_color_lab[0] * adaptive_l_strength)

                # Update Lab values
                ring_pixels_lab[contaminated_mask, 0] = corrected_l
                ring_pixels_lab[contaminated_mask, 1:3] = corrected_chroma

                # Write back to original image
                orig_lab[ring_zone] = ring_pixels_lab

            # Dark edge protection
            if HARD_EDGE_PROTECT:
                gray = np.mean(orig_array, axis=2)
                # Detect dark, high-gradient areas
                sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
                sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
                gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)

                dark_edge_zone = ring_zone & (gray < 60) & (gradient_mag > 20)
                # Protect these areas from excessive modification, copy directly from the original
                if np.any(dark_edge_zone):
                    orig_lab[dark_edge_zone] = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB)[dark_edge_zone]

            # Multi-pass correction for stubborn spill
            if MULTI_PASS_CORRECTION:
                # Second pass for remaining contamination
                ring_pixels_lab_pass2 = orig_lab[ring_zone]
                delta_l_pass2 = ring_pixels_lab_pass2[:, 0] - orig_bg_color_lab[0]
                delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
                delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
                delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)

                still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)

                if np.any(still_contaminated):
                    # Apply stronger correction to remaining contaminated pixels
                    remaining_pixels = ring_pixels_lab_pass2[still_contaminated]

                    # More aggressive chroma neutralization
                    remaining_chroma = remaining_pixels[:, 1:3]
                    neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7

                    # Stronger luminance matching
                    neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6

                    ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
                    ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
                    orig_lab[ring_zone] = ring_pixels_lab_pass2

            # Convert back to RGB
            orig_lab_clipped = np.clip(orig_lab, 0, 255).astype(np.uint8)
            orig_array_corrected = cv2.cvtColor(orig_lab_clipped, cv2.COLOR_LAB2RGB)

            # Inpaint fallback repair
            if INPAINT_FALLBACK:
                # Inpaint outermost pixels that are still contaminated
                final_contaminated = ring_zone.copy()

                # Check whether contamination remains after the repair passes
                final_lab = cv2.cvtColor(orig_array_corrected, cv2.COLOR_RGB2LAB).astype(np.float32)
                final_ring_lab = final_lab[ring_zone]
                final_delta_l = final_ring_lab[:, 0] - orig_bg_color_lab[0]
                final_delta_a = final_ring_lab[:, 1] - orig_bg_color_lab[1]
                final_delta_b = final_ring_lab[:, 2] - orig_bg_color_lab[2]
                final_delta_e = np.sqrt(final_delta_l**2 + final_delta_a**2 + final_delta_b**2)
|
| 649 |
+
|
| 650 |
+
still_contaminated = final_delta_e < (DELTAE_THRESHOLD * 0.5)
|
| 651 |
+
if np.any(still_contaminated):
|
| 652 |
+
# Create inpaint mask
|
| 653 |
+
inpaint_mask = np.zeros((height, width), dtype=np.uint8)
|
| 654 |
+
ring_coords = np.where(ring_zone)
|
| 655 |
+
inpaint_coords = (ring_coords[0][still_contaminated], ring_coords[1][still_contaminated])
|
| 656 |
+
inpaint_mask[inpaint_coords] = 255
|
| 657 |
+
|
| 658 |
+
# Execute inpaint
|
| 659 |
+
try:
|
| 660 |
+
orig_array_corrected = cv2.inpaint(orig_array_corrected, inpaint_mask, 3, cv2.INPAINT_TELEA)
|
| 661 |
+
except Exception:
|
| 662 |
+
# Fallback: directly cover with foreground representative color
|
| 663 |
+
orig_array_corrected[inpaint_coords] = fg_rep_color_rgb
|
| 664 |
+
|
| 665 |
+
orig_array = orig_array_corrected
|
| 666 |
+
|
| 667 |
+
# === Linear space blending (keep original logic) ===
|
| 668 |
+
def srgb_to_linear(img):
|
| 669 |
+
img_f = img.astype(np.float32) / 255.0
|
| 670 |
+
return np.where(img_f <= 0.04045, img_f / 12.92, np.power((img_f + 0.055) / 1.055, 2.4))
|
| 671 |
+
|
| 672 |
+
def linear_to_srgb(img):
|
| 673 |
+
img_clipped = np.clip(img, 0, 1)
|
| 674 |
+
return np.where(img_clipped <= 0.0031308,
|
| 675 |
+
12.92 * img_clipped,
|
| 676 |
+
1.055 * np.power(img_clipped, 1/2.4) - 0.055)
|
| 677 |
+
|
| 678 |
+
orig_linear = srgb_to_linear(orig_array)
|
| 679 |
+
bg_linear = srgb_to_linear(bg_array)
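Compositing in linear light avoids the dark fringes that appear when gamma-encoded sRGB values are mixed directly. The two converters above are the standard sRGB transfer functions; a quick round-trip check (the tolerance is an assumption) confirms they invert each other:

```python
import numpy as np

def srgb_to_linear(img):
    img_f = img.astype(np.float32) / 255.0
    return np.where(img_f <= 0.04045, img_f / 12.92, ((img_f + 0.055) / 1.055) ** 2.4)

def linear_to_srgb(img):
    img_c = np.clip(img, 0, 1)
    return np.where(img_c <= 0.0031308, 12.92 * img_c, 1.055 * img_c ** (1 / 2.4) - 0.055)

x = np.arange(256, dtype=np.uint8)
roundtrip = linear_to_srgb(srgb_to_linear(x)) * 255
assert np.abs(roundtrip - x).max() < 0.5   # lossless to within rounding error
```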
|
| 680 |
+
|
| 681 |
+
# === Cartoon-optimized Alpha calculation ===
|
| 682 |
+
alpha = mask_array.astype(np.float32) / 255.0
|
| 683 |
+
|
| 684 |
+
# Core foreground region - fully opaque
|
| 685 |
+
alpha[fg_core] = 1.0
|
| 686 |
+
|
| 687 |
+
# Background region - fully transparent
|
| 688 |
+
alpha[bg_zone] = 0.0
|
| 689 |
+
|
| 690 |
+
# [Key Fix] Force pixels with mask≥160 to α=1.0, avoiding white fill areas being limited to 0.9
|
| 691 |
+
high_confidence_pixels = mask_array >= 160
|
| 692 |
+
alpha[high_confidence_pixels] = 1.0
|
| 693 |
+
logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
|
| 694 |
+
|
| 695 |
+
# Ring area can be dehaloed, but doesn't affect already set high confidence pixels
|
| 696 |
+
ring_without_high_conf = ring_zone & (~high_confidence_pixels)
|
| 697 |
+
alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
|
| 698 |
+
|
| 699 |
+
# Retain existing black outline/strong edge protection
|
| 700 |
+
orig_gray = np.mean(orig_array, axis=2)
|
| 701 |
+
|
| 702 |
+
# Detect strong edge areas
|
| 703 |
+
sobel_x = cv2.Sobel(orig_gray, cv2.CV_64F, 1, 0, ksize=3)
|
| 704 |
+
sobel_y = cv2.Sobel(orig_gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 705 |
+
gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
|
| 706 |
+
|
| 707 |
+
# Black outline/strong edge protection: nearly fully opaque
|
| 708 |
+
black_edge_threshold = 60 # black edge threshold
|
| 709 |
+
gradient_threshold = 25 # gradient threshold
|
| 710 |
+
strong_edges = (orig_gray < black_edge_threshold) & (gradient_mag > gradient_threshold) & (mask_array > 10)
|
| 711 |
+
alpha[strong_edges] = np.maximum(alpha[strong_edges], 0.995) # black edge alpha
|
| 712 |
+
|
| 713 |
+
logger.info(f"🛡️ Protection applied - High conf: {high_confidence_pixels.sum()}, Strong edges: {strong_edges.sum()}")
|
| 714 |
+
|
| 715 |
+
# Apply edge alpha binarization to eliminate semi-transparent artifacts
|
| 716 |
+
alpha = self._binarize_edge_alpha(
|
| 717 |
+
alpha,
|
| 718 |
+
mask_array,
|
| 719 |
+
orig_array,
|
| 720 |
+
threshold=self.ALPHA_BINARIZE_THRESHOLD
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
# Final blending
|
| 724 |
+
alpha_3d = alpha[:, :, np.newaxis]
|
| 725 |
+
result_linear = orig_linear * alpha_3d + bg_linear * (1 - alpha_3d)
|
| 726 |
+
result_srgb = linear_to_srgb(result_linear)
|
| 727 |
+
result_array = (result_srgb * 255).astype(np.uint8)
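The blend itself is the standard alpha "over" operator, result = alpha * fg + (1 - alpha) * bg, evaluated per pixel in linear RGB before converting back. A tiny 1x2-pixel sketch (all values assumed):

```python
import numpy as np

alpha = np.array([[1.0, 0.25]])[..., None]            # one opaque pixel, one mostly-background pixel
fg = np.array([[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])   # red foreground, linear RGB
bg = np.array([[[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]])   # blue background, linear RGB

out = fg * alpha + bg * (1 - alpha)
print(out)   # [[[1, 0, 0], [0.25, 0, 0.75]]]
```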
|
| 728 |
+
|
| 729 |
+
# Final edge cleanup pass
|
| 730 |
+
result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
|
| 731 |
+
|
| 732 |
+
# Protect core foreground from any background influence
|
| 733 |
+
# This ensures faces and bodies retain original colors
|
| 734 |
+
result_array = self._protect_foreground_core(
|
| 735 |
+
result_array,
|
| 736 |
+
np.array(original_image, dtype=np.uint8), # Use original unprocessed image
|
| 737 |
+
mask_array,
|
| 738 |
+
protection_threshold=self.FOREGROUND_PROTECTION_THRESHOLD
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
# Store debug information (for debug output)
|
| 742 |
+
self._debug_info = {
|
| 743 |
+
'orig_bg_color_rgb': orig_bg_color_rgb,
|
| 744 |
+
'fg_rep_color_rgb': fg_rep_color_rgb,
|
| 745 |
+
'orig_bg_color_lab': orig_bg_color_lab,
|
| 746 |
+
'fg_rep_color_lab': fg_rep_color_lab,
|
| 747 |
+
'ring_zone': ring_zone,
|
| 748 |
+
'fg_core': fg_core,
|
| 749 |
+
'alpha_final': alpha
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
return Image.fromarray(result_array)
|
| 753 |
+
|
| 754 |
+
def create_debug_images(
|
| 755 |
+
self,
|
| 756 |
+
original_image: Image.Image,
|
| 757 |
+
generated_background: Image.Image,
|
| 758 |
+
combination_mask: Image.Image,
|
| 759 |
+
combined_image: Image.Image
|
| 760 |
+
) -> Dict[str, Image.Image]:
|
| 761 |
+
"""
|
| 762 |
+
Generate debug images: (a) Final mask grayscale (b) Alpha heatmap (c) Ring visualization overlay
|
| 763 |
+
"""
|
| 764 |
+
debug_images = {}
|
| 765 |
+
|
| 766 |
+
# Final Mask grayscale
|
| 767 |
+
debug_images["mask_gray"] = combination_mask.convert('L')
|
| 768 |
+
|
| 769 |
+
# Alpha Heatmap
|
| 770 |
+
mask_array = np.array(combination_mask.convert('L'))
|
| 771 |
+
heatmap_colored = cv2.applyColorMap(mask_array, cv2.COLORMAP_JET)
|
| 772 |
+
heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
|
| 773 |
+
debug_images["alpha_heatmap"] = Image.fromarray(heatmap_rgb)
|
| 774 |
+
|
| 775 |
+
# Ring visualization overlay - show ring areas on original image
|
| 776 |
+
if hasattr(self, '_debug_info') and 'ring_zone' in self._debug_info:
|
| 777 |
+
ring_zone = self._debug_info['ring_zone']
|
| 778 |
+
orig_array = np.array(original_image)
|
| 779 |
+
ring_overlay = orig_array.copy()
|
| 780 |
+
|
| 781 |
+
# Mark ring areas with red semi-transparent overlay
|
| 782 |
+
ring_overlay[ring_zone] = ring_overlay[ring_zone] * 0.7 + np.array([255, 0, 0]) * 0.3
|
| 783 |
+
debug_images["ring_visualization"] = Image.fromarray(ring_overlay.astype(np.uint8))
|
| 784 |
+
else:
|
| 785 |
+
# If no ring information, use original image
|
| 786 |
+
debug_images["ring_visualization"] = original_image
|
| 787 |
+
|
| 788 |
+
# Adaptive strength heatmap - visualize per-pixel correction strength
|
| 789 |
+
if hasattr(self, '_adaptive_strength_map') and self._adaptive_strength_map is not None:
|
| 790 |
+
# Normalize adaptive strength to 0-255 for visualization
|
| 791 |
+
strength_map = self._adaptive_strength_map
|
| 792 |
+
if strength_map.max() > 0:
|
| 793 |
+
normalized_strength = (strength_map / strength_map.max() * 255).astype(np.uint8)
|
| 794 |
+
else:
|
| 795 |
+
normalized_strength = np.zeros_like(strength_map, dtype=np.uint8)
|
| 796 |
+
|
| 797 |
+
# Apply colormap
|
| 798 |
+
strength_heatmap = cv2.applyColorMap(normalized_strength, cv2.COLORMAP_VIRIDIS)
|
| 799 |
+
strength_heatmap_rgb = cv2.cvtColor(strength_heatmap, cv2.COLOR_BGR2RGB)
|
| 800 |
+
debug_images["adaptive_strength_heatmap"] = Image.fromarray(strength_heatmap_rgb)
|
| 801 |
+
|
| 802 |
+
return debug_images
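A possible way to inspect the debug set after a blending run; the `blender` instance and the four image variables are assumptions, not part of this module:

```python
# Hypothetical usage: `blender` is an instance of this class and the four
# images come from a previous blend (original, generated bg, mask, result).
debug = blender.create_debug_images(original, background, mask, combined)
for name, img in debug.items():
    img.save(f"debug_{name}.png")   # e.g. debug_alpha_heatmap.png
```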
|
mask_generator.py
ADDED
|
@@ -0,0 +1,650 @@
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image, ImageFilter, ImageDraw
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
from scipy.ndimage import binary_erosion, binary_dilation
|
| 7 |
+
import io
|
| 8 |
+
import gc
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoModelForImageSegmentation
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from rembg import remove, new_session
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
logger.setLevel(logging.INFO)
|
| 16 |
+
|
| 17 |
+
class MaskGenerator:
|
| 18 |
+
"""
|
| 19 |
+
Intelligent mask generation using deep learning models with traditional fallback.
|
| 20 |
+
Priority: BiRefNet > U²-Net (rembg) > Traditional gradient-based methods
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, max_image_size: int = 1024, device: str = "auto"):
|
| 24 |
+
self.max_image_size = max_image_size
|
| 25 |
+
self.device = self._setup_device(device)
|
| 26 |
+
|
| 27 |
+
# BiRefNet model (lazy loading)
|
| 28 |
+
self._birefnet_model = None
|
| 29 |
+
self._birefnet_transform = None
|
| 30 |
+
|
| 31 |
+
# Log initialization
|
| 32 |
+
logger.info(f"🎭 MaskGenerator initialized on {self.device}")
|
| 33 |
+
|
| 34 |
+
def _setup_device(self, device: str) -> str:
|
| 35 |
+
"""Setup computation device"""
|
| 36 |
+
if device == "auto":
|
| 37 |
+
if torch.cuda.is_available():
|
| 38 |
+
return "cuda"
|
| 39 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 40 |
+
return "mps"
|
| 41 |
+
return "cpu"
|
| 42 |
+
return device
|
| 43 |
+
|
| 44 |
+
def _load_birefnet_model(self) -> bool:
|
| 45 |
+
"""
|
| 46 |
+
Lazy load BiRefNet model for memory efficiency.
|
| 47 |
+
Returns True if model loaded successfully, False otherwise.
|
| 48 |
+
"""
|
| 49 |
+
if self._birefnet_model is not None:
|
| 50 |
+
return True
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
logger.info("📥 Loading BiRefNet model (ZhengPeng7/BiRefNet)...")
|
| 54 |
+
|
| 55 |
+
# Load model with fp16 for memory efficiency on GPU
|
| 56 |
+
dtype = torch.float16 if self.device == "cuda" else torch.float32
|
| 57 |
+
|
| 58 |
+
self._birefnet_model = AutoModelForImageSegmentation.from_pretrained(
|
| 59 |
+
"ZhengPeng7/BiRefNet",
|
| 60 |
+
trust_remote_code=True,
|
| 61 |
+
torch_dtype=dtype
|
| 62 |
+
)
|
| 63 |
+
self._birefnet_model.to(self.device)
|
| 64 |
+
self._birefnet_model.eval()
|
| 65 |
+
|
| 66 |
+
# Define preprocessing transform
|
| 67 |
+
self._birefnet_transform = transforms.Compose([
|
| 68 |
+
transforms.Resize((1024, 1024)),
|
| 69 |
+
transforms.ToTensor(),
|
| 70 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 71 |
+
])
|
| 72 |
+
|
| 73 |
+
logger.info("✅ BiRefNet model loaded successfully")
|
| 74 |
+
return True
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.error(f"❌ Failed to load BiRefNet: {e}")
|
| 78 |
+
self._birefnet_model = None
|
| 79 |
+
self._birefnet_transform = None
|
| 80 |
+
return False
|
| 81 |
+
|
| 82 |
+
def _unload_birefnet_model(self):
|
| 83 |
+
"""Unload BiRefNet model to free memory"""
|
| 84 |
+
if self._birefnet_model is not None:
|
| 85 |
+
del self._birefnet_model
|
| 86 |
+
self._birefnet_model = None
|
| 87 |
+
self._birefnet_transform = None
|
| 88 |
+
|
| 89 |
+
if torch.cuda.is_available():
|
| 90 |
+
torch.cuda.empty_cache()
|
| 91 |
+
gc.collect()
|
| 92 |
+
logger.info("🧹 BiRefNet model unloaded")
|
| 93 |
+
|
| 94 |
+
def apply_guided_filter(
|
| 95 |
+
self,
|
| 96 |
+
mask: np.ndarray,
|
| 97 |
+
guide_image: Image.Image,
|
| 98 |
+
radius: int = 8,
|
| 99 |
+
eps: float = 0.01
|
| 100 |
+
) -> np.ndarray:
|
| 101 |
+
"""
|
| 102 |
+
Apply guided filter to mask for edge-preserving smoothing.
|
| 103 |
+
Returns the original mask unchanged if the guided filter (cv2.ximgproc, from opencv-contrib-python) is unavailable or fails.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
mask: Input mask as numpy array (0-255)
|
| 107 |
+
guide_image: Original image to use as guide
|
| 108 |
+
radius: Filter radius (larger = more smoothing)
|
| 109 |
+
eps: Regularization parameter (smaller = more edge-preserving)
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Filtered mask as numpy array (0-255)
|
| 113 |
+
"""
|
| 114 |
+
try:
|
| 115 |
+
# Convert guide image to grayscale
|
| 116 |
+
guide_gray = np.array(guide_image.convert('L')).astype(np.float32) / 255.0
|
| 117 |
+
mask_float = mask.astype(np.float32) / 255.0
|
| 118 |
+
|
| 119 |
+
logger.info(f"🔧 Applying guided filter (radius={radius}, eps={eps})")
|
| 120 |
+
|
| 121 |
+
# Apply guided filter
|
| 122 |
+
filtered = cv2.ximgproc.guidedFilter(
|
| 123 |
+
guide=guide_gray,
|
| 124 |
+
src=mask_float,
|
| 125 |
+
radius=radius,
|
| 126 |
+
eps=eps
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Convert back to 0-255 range
|
| 130 |
+
result = (np.clip(filtered, 0, 1) * 255).astype(np.uint8)
|
| 131 |
+
logger.info("✅ Guided filter applied successfully")
|
| 132 |
+
return result
|
| 133 |
+
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"❌ Guided filter failed: {e}, using original mask")
|
| 136 |
+
return mask
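Note that `cv2.ximgproc.guidedFilter` ships with the opencv-contrib-python wheel rather than plain opencv-python, so on a default install the except branch above is the path that actually runs. A small sketch of an up-front capability check (not part of this module):

```python
import cv2

def has_guided_filter() -> bool:
    """True when the ximgproc contrib module (opencv-contrib-python) is installed."""
    return hasattr(cv2, "ximgproc") and hasattr(cv2.ximgproc, "guidedFilter")

if not has_guided_filter():
    print("guidedFilter unavailable - masks will be returned unfiltered")
```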
|
| 137 |
+
|
| 138 |
+
def try_birefnet_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
|
| 139 |
+
"""
|
| 140 |
+
Generate foreground mask using BiRefNet model.
|
| 141 |
+
BiRefNet provides high-quality segmentation with clean edges.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
original_image: Input PIL Image
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
PIL Image (L mode) mask or None if failed
|
| 148 |
+
"""
|
| 149 |
+
try:
|
| 150 |
+
# Lazy load model
|
| 151 |
+
if not self._load_birefnet_model():
|
| 152 |
+
return None
|
| 153 |
+
|
| 154 |
+
logger.info("🤖 Starting BiRefNet foreground extraction...")
|
| 155 |
+
original_size = original_image.size
|
| 156 |
+
|
| 157 |
+
# Convert to RGB if needed
|
| 158 |
+
if original_image.mode != 'RGB':
|
| 159 |
+
image_rgb = original_image.convert('RGB')
|
| 160 |
+
else:
|
| 161 |
+
image_rgb = original_image
|
| 162 |
+
|
| 163 |
+
# Preprocess image
|
| 164 |
+
input_tensor = self._birefnet_transform(image_rgb).unsqueeze(0)
|
| 165 |
+
|
| 166 |
+
# Move to device with appropriate dtype
|
| 167 |
+
if self.device == "cuda":
|
| 168 |
+
input_tensor = input_tensor.to(self.device, dtype=torch.float16)
|
| 169 |
+
else:
|
| 170 |
+
input_tensor = input_tensor.to(self.device)
|
| 171 |
+
|
| 172 |
+
# Run inference
|
| 173 |
+
with torch.no_grad():
|
| 174 |
+
outputs = self._birefnet_model(input_tensor)
|
| 175 |
+
|
| 176 |
+
# BiRefNet outputs a list, get the final prediction
|
| 177 |
+
if isinstance(outputs, (list, tuple)):
|
| 178 |
+
pred = outputs[-1]
|
| 179 |
+
else:
|
| 180 |
+
pred = outputs
|
| 181 |
+
|
| 182 |
+
# Sigmoid to get probability map
|
| 183 |
+
pred = torch.sigmoid(pred)
|
| 184 |
+
|
| 185 |
+
# Convert to numpy
|
| 186 |
+
pred_np = pred.squeeze().cpu().numpy()
|
| 187 |
+
|
| 188 |
+
# Convert to 0-255 range
|
| 189 |
+
mask_array = (pred_np * 255).astype(np.uint8)
|
| 190 |
+
|
| 191 |
+
# Resize back to original size
|
| 192 |
+
mask_pil = Image.fromarray(mask_array, mode='L')
|
| 193 |
+
mask_pil = mask_pil.resize(original_size, Image.LANCZOS)
|
| 194 |
+
mask_array = np.array(mask_pil)
|
| 195 |
+
|
| 196 |
+
# Quality check
|
| 197 |
+
mean_val = mask_array.mean()
|
| 198 |
+
nonzero_ratio = np.count_nonzero(mask_array > 50) / mask_array.size
|
| 199 |
+
|
| 200 |
+
logger.info(f"📊 BiRefNet mask stats - Mean: {mean_val:.1f}, Coverage: {nonzero_ratio:.1%}")
|
| 201 |
+
|
| 202 |
+
if mean_val < 10:
|
| 203 |
+
logger.warning("⚠️ BiRefNet mask too weak, falling back")
|
| 204 |
+
return None
|
| 205 |
+
|
| 206 |
+
if nonzero_ratio < 0.03:
|
| 207 |
+
logger.warning("⚠️ BiRefNet foreground coverage too low, falling back")
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
# Light post-processing for edge refinement
|
| 211 |
+
# Use morphological operations to clean up
|
| 212 |
+
kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 213 |
+
mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_CLOSE, kernel_small)
|
| 214 |
+
|
| 215 |
+
logger.info("✅ BiRefNet mask generation successful!")
|
| 216 |
+
return Image.fromarray(mask_array, mode='L')
|
| 217 |
+
|
| 218 |
+
except torch.cuda.OutOfMemoryError:
|
| 219 |
+
logger.error("❌ BiRefNet: GPU memory exhausted")
|
| 220 |
+
self._unload_birefnet_model()
|
| 221 |
+
return None
|
| 222 |
+
|
| 223 |
+
except Exception as e:
|
| 224 |
+
logger.error(f"❌ BiRefNet mask generation failed: {e}")
|
| 225 |
+
import traceback
|
| 226 |
+
logger.error(f"📍 Traceback: {traceback.format_exc()}")
|
| 227 |
+
return None
|
| 228 |
+
|
| 229 |
+
def try_deep_learning_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
|
| 230 |
+
"""
|
| 231 |
+
Intelligent foreground extraction with model priority:
|
| 232 |
+
1. BiRefNet (best quality, clean edges)
|
| 233 |
+
2. U²-Net via rembg (good fallback)
|
| 234 |
+
3. Return None to trigger traditional methods
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
original_image: Input PIL Image
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
PIL Image (L mode) mask or None if all methods failed
|
| 241 |
+
"""
|
| 242 |
+
# Priority 1: Try BiRefNet first
|
| 243 |
+
logger.info("🤖 Attempting BiRefNet mask generation...")
|
| 244 |
+
birefnet_mask = self.try_birefnet_mask(original_image)
|
| 245 |
+
if birefnet_mask is not None:
|
| 246 |
+
logger.info("✅ Using BiRefNet generated mask")
|
| 247 |
+
return birefnet_mask
|
| 248 |
+
|
| 249 |
+
# Priority 2: Fallback to rembg (U²-Net)
|
| 250 |
+
logger.info("🔄 BiRefNet unavailable/failed, trying rembg...")
|
| 251 |
+
try:
|
| 252 |
+
logger.info("🤖 Starting rembg foreground extraction")
|
| 253 |
+
|
| 254 |
+
# Try u2net first (better for cartoons/objects like Snoopy)
|
| 255 |
+
try:
|
| 256 |
+
session = new_session('u2net')
|
| 257 |
+
logger.info("✅ Using u2net model")
|
| 258 |
+
except Exception as e:
|
| 259 |
+
logger.warning(f"u2net failed ({e}), trying u2net_human_seg")
|
| 260 |
+
try:
|
| 261 |
+
session = new_session('u2net_human_seg')
|
| 262 |
+
logger.info("✅ Using u2net_human_seg model")
|
| 263 |
+
except Exception as e2:
|
| 264 |
+
logger.error(f"All rembg models failed: {e2}")
|
| 265 |
+
return None
|
| 266 |
+
|
| 267 |
+
# Convert image to bytes for rembg
|
| 268 |
+
img_byte_arr = io.BytesIO()
|
| 269 |
+
original_image.save(img_byte_arr, format='PNG')
|
| 270 |
+
img_byte_arr = img_byte_arr.getvalue()
|
| 271 |
+
logger.info(f"📷 Image size: {len(img_byte_arr)} bytes")
|
| 272 |
+
|
| 273 |
+
# Perform background removal
|
| 274 |
+
result = remove(img_byte_arr, session=session)
|
| 275 |
+
result_img = Image.open(io.BytesIO(result)).convert('RGBA')
|
| 276 |
+
alpha_channel = result_img.split()[-1]
|
| 277 |
+
alpha_array = np.array(alpha_channel)
|
| 278 |
+
|
| 279 |
+
logger.info(f"📊 Raw alpha stats - Mean: {alpha_array.mean():.1f}, Min: {alpha_array.min()}, Max: {alpha_array.max()}")
|
| 280 |
+
|
| 281 |
+
# Step 1: Light smoothing to reduce noise but preserve edges
|
| 282 |
+
alpha_smoothed = cv2.GaussianBlur(alpha_array, (3, 3), 0.8)
|
| 283 |
+
|
| 284 |
+
# Step 2: Contrast stretching to utilize full range
|
| 285 |
+
alpha_stretched = cv2.normalize(alpha_smoothed, None, 0, 255, cv2.NORM_MINMAX)
|
| 286 |
+
|
| 287 |
+
# Step 3: CRITICAL FIX - More aggressive foreground preservation
|
| 288 |
+
# Instead of hard threshold, use adaptive approach
|
| 289 |
+
|
| 290 |
+
# Find the main subject area (high confidence regions)
|
| 291 |
+
high_confidence = alpha_stretched > 180
|
| 292 |
+
medium_confidence = (alpha_stretched > 60) & (alpha_stretched <= 180)
|
| 293 |
+
low_confidence = (alpha_stretched > 15) & (alpha_stretched <= 60)
|
| 294 |
+
|
| 295 |
+
# Create final mask with better extremity handling
|
| 296 |
+
final_alpha = np.zeros_like(alpha_stretched)
|
| 297 |
+
|
| 298 |
+
# High confidence areas - keep at full opacity
|
| 299 |
+
final_alpha[high_confidence] = 255
|
| 300 |
+
|
| 301 |
+
# Medium confidence - boost significantly
|
| 302 |
+
final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)
|
| 303 |
+
|
| 304 |
+
# Low confidence - moderate boost (catches faint extremities)
|
| 305 |
+
final_alpha[low_confidence] = np.clip(alpha_stretched[low_confidence] * 2.5, 120, 199)
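Concretely, the three tiers push everything rembg was at least mildly confident about toward full opacity; a quick check with one representative raw alpha value per boosted tier (values assumed):

```python
import numpy as np

print(np.clip(100 * 1.8, 200, 255))   # 200.0 - medium-confidence pixel raised to near-opaque
print(np.clip(40 * 2.5, 120, 199))    # 120.0 - low-confidence pixel lifted, kept below the medium band
```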
|
| 306 |
+
|
| 307 |
+
# Morphological operations to connect disconnected parts (hands, feet, tail)
|
| 308 |
+
kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 309 |
+
kernel_medium = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
| 310 |
+
|
| 311 |
+
# Close small gaps (helps connect separated body parts)
|
| 312 |
+
final_alpha = cv2.morphologyEx(final_alpha, cv2.MORPH_CLOSE, kernel_small, iterations=1)
|
| 313 |
+
|
| 314 |
+
# Light dilation to ensure nothing gets cut off
|
| 315 |
+
final_alpha = cv2.dilate(final_alpha, kernel_small, iterations=1)
|
| 316 |
+
|
| 317 |
+
logger.info(f"📊 Final alpha stats - Mean: {final_alpha.mean():.1f}, Min: {final_alpha.min()}, Max: {final_alpha.max()}")
|
| 318 |
+
|
| 319 |
+
# Quality check - but be more lenient for cartoon characters
|
| 320 |
+
if final_alpha.mean() < 10:
|
| 321 |
+
logger.warning("⚠️ Alpha still too weak, falling back to traditional method")
|
| 322 |
+
return None
|
| 323 |
+
|
| 324 |
+
# Enhanced post-processing for cartoon characters
|
| 325 |
+
is_cartoon = self._detect_cartoon_character(original_image, final_alpha)
|
| 326 |
+
|
| 327 |
+
if is_cartoon:
|
| 328 |
+
logger.info("🎭 Detected cartoon/character image, applying specialized processing")
|
| 329 |
+
final_alpha = self._enhance_cartoon_mask(original_image, final_alpha)
|
| 330 |
+
|
| 331 |
+
# Count non-zero pixels to ensure we have substantial foreground
|
| 332 |
+
foreground_pixels = np.count_nonzero(final_alpha > 50)
|
| 333 |
+
total_pixels = final_alpha.size
|
| 334 |
+
foreground_ratio = foreground_pixels / total_pixels
|
| 335 |
+
logger.info(f"📊 Foreground coverage: {foreground_ratio:.1%} of image")
|
| 336 |
+
|
| 337 |
+
if foreground_ratio < 0.05: # Less than 5% is probably too little
|
| 338 |
+
logger.warning("⚠️ Very low foreground coverage, falling back to traditional method")
|
| 339 |
+
return None
|
| 340 |
+
|
| 341 |
+
mask = Image.fromarray(final_alpha.astype(np.uint8), mode='L')
|
| 342 |
+
logger.info("✅ Enhanced rembg mask generation successful!")
|
| 343 |
+
return mask
|
| 344 |
+
|
| 345 |
+
except Exception as e:
|
| 346 |
+
logger.error(f"❌ Deep learning mask extraction failed: {e}")
|
| 347 |
+
return None
|
| 348 |
+
|
| 349 |
+
def _detect_cartoon_character(self, original_image: Image.Image, alpha_mask: np.ndarray) -> bool:
|
| 350 |
+
"""
|
| 351 |
+
Detect if image is cartoon/line art (heuristic approach)
|
| 352 |
+
"""
|
| 353 |
+
try:
|
| 354 |
+
img_array = np.array(original_image.convert('RGB'))
|
| 355 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 356 |
+
|
| 357 |
+
# Calculate edge density (cartoons usually have more clear edges)
|
| 358 |
+
edges = cv2.Canny(gray, 50, 150)
|
| 359 |
+
edge_density = np.count_nonzero(edges) / max(edges.size, 1) # Avoid division by zero
|
| 360 |
+
|
| 361 |
+
# Calculate color complexity (cartoons usually have fewer colors) - optimize memory usage
|
| 362 |
+
h, w, c = img_array.shape
|
| 363 |
+
if h * w > 100000: # If image is too large, resize for processing
|
| 364 |
+
small_img = cv2.resize(img_array, (200, 200))
|
| 365 |
+
else:
|
| 366 |
+
small_img = img_array
|
| 367 |
+
|
| 368 |
+
unique_colors = len(np.unique(small_img.reshape(-1, 3), axis=0))
|
| 369 |
+
total_pixels = small_img.shape[0] * small_img.shape[1]
|
| 370 |
+
color_simplicity = unique_colors < (total_pixels * 0.1)
|
| 371 |
+
|
| 372 |
+
# Check for obvious black outlines
|
| 373 |
+
dark_pixels_ratio = np.count_nonzero(gray < 50) / max(gray.size, 1) # Avoid division by zero
|
| 374 |
+
has_black_outline = dark_pixels_ratio > 0.05
|
| 375 |
+
|
| 376 |
+
# Comprehensive judgment: high edge density + color simplicity + black outline = likely cartoon
|
| 377 |
+
is_cartoon = (edge_density > 0.05) and (color_simplicity or has_black_outline)
|
| 378 |
+
|
| 379 |
+
logger.info(f"🔍 Cartoon detection - Edge density: {edge_density:.3f}, Color simplicity: {color_simplicity}, Black outline: {has_black_outline} -> Cartoon: {is_cartoon}")
|
| 380 |
+
return is_cartoon
|
| 381 |
+
|
| 382 |
+
except Exception as e:
|
| 383 |
+
import traceback
|
| 384 |
+
logger.error(f"❌ Cartoon detection failed: {e}")
|
| 385 |
+
logger.error(f"📍 Traceback: {traceback.format_exc()}")
|
| 386 |
+
print(f"❌ CARTOON DETECTION ERROR: {e}")
|
| 387 |
+
print(f"Traceback: {traceback.format_exc()}")
|
| 388 |
+
return False
|
| 389 |
+
|
| 390 |
+
def _enhance_cartoon_mask(self, original_image: Image.Image, alpha_mask: np.ndarray) -> np.ndarray:
|
| 391 |
+
"""
|
| 392 |
+
Enhanced mask processing for cartoon characters
|
| 393 |
+
"""
|
| 394 |
+
try:
|
| 395 |
+
img_array = np.array(original_image.convert('RGB'))
|
| 396 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 397 |
+
enhanced_alpha = alpha_mask.copy()
|
| 398 |
+
|
| 399 |
+
# Step 1: Black outline enhancement - find black outlines and enhance their alpha
|
| 400 |
+
th_dark = 80 # Adjustable parameter: black threshold
|
| 401 |
+
black_outline = gray < th_dark
|
| 402 |
+
|
| 403 |
+
# Dilate black outline region by 1px
|
| 404 |
+
kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) # Adjustable parameter: dilation kernel size
|
| 405 |
+
black_outline_dilated = cv2.dilate(black_outline.astype(np.uint8), kernel_dilate, iterations=1)
|
| 406 |
+
|
| 407 |
+
# Set black outline region alpha directly to 255
|
| 408 |
+
enhanced_alpha[black_outline_dilated > 0] = 255
|
| 409 |
+
logger.info(f"🖤 Black outline enhanced: {np.count_nonzero(black_outline_dilated)} pixels")
|
| 410 |
+
|
| 411 |
+
# Step 2: Simplified internal enhancement - process white fill areas within outlines
|
| 412 |
+
# Find high confidence regions (alpha ≥ 160)
|
| 413 |
+
high_confidence = enhanced_alpha >= 160
|
| 414 |
+
|
| 415 |
+
# Apply close operation on high confidence regions to connect separated parts
|
| 416 |
+
kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) # Adjustable parameter: close kernel size
|
| 417 |
+
high_confidence_closed = cv2.morphologyEx(high_confidence.astype(np.uint8), cv2.MORPH_CLOSE, kernel_close, iterations=1)
|
| 418 |
+
|
| 419 |
+
# Simplified approach: directly enhance medium confidence regions without complex flood fill
|
| 420 |
+
# Find medium/low confidence regions surrounded by high confidence regions
|
| 421 |
+
medium_confidence = (enhanced_alpha >= 80) & (enhanced_alpha < 160)
|
| 422 |
+
|
| 423 |
+
# Dilate high confidence region to include more internal areas
|
| 424 |
+
kernel_dilate_internal = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 425 |
+
high_confidence_expanded = cv2.dilate(high_confidence_closed, kernel_dilate_internal, iterations=1)
|
| 426 |
+
|
| 427 |
+
# Medium confidence pixels within expanded high confidence areas are considered internal fill
|
| 428 |
+
internal_fill_regions = medium_confidence & (high_confidence_expanded > 0)
|
| 429 |
+
|
| 430 |
+
# Enhance alpha of these internal fill regions to at least 220
|
| 431 |
+
min_alpha_for_fill = 220 # Adjustable parameter: minimum alpha for internal fill
|
| 432 |
+
enhanced_alpha[internal_fill_regions] = np.maximum(enhanced_alpha[internal_fill_regions], min_alpha_for_fill)
|
| 433 |
+
|
| 434 |
+
logger.info(f"🤍 Internal fill regions enhanced: {np.count_nonzero(internal_fill_regions)} pixels")
|
| 435 |
+
logger.info(f"📊 Enhanced alpha stats - Mean: {enhanced_alpha.mean():.1f}, Min: {enhanced_alpha.min()}, Max: {enhanced_alpha.max()}")
|
| 436 |
+
|
| 437 |
+
return enhanced_alpha
|
| 438 |
+
|
| 439 |
+
except Exception as e:
|
| 440 |
+
import traceback
|
| 441 |
+
logger.error(f"❌ Cartoon mask enhancement failed: {e}")
|
| 442 |
+
logger.error(f"📍 Traceback: {traceback.format_exc()}")
|
| 443 |
+
print(f"❌ CARTOON MASK ENHANCEMENT ERROR: {e}")
|
| 444 |
+
print(f"Traceback: {traceback.format_exc()}")
|
| 445 |
+
return alpha_mask
|
| 446 |
+
|
| 447 |
+
def _adjust_mask_for_scene_focus(self, mask: Image.Image, original_image: Image.Image) -> Image.Image:
|
| 448 |
+
"""
|
| 449 |
+
Adjust mask for scene focus mode to include nearby objects like chairs, furniture
|
| 450 |
+
"""
|
| 451 |
+
try:
|
| 452 |
+
logger.info("🏠 Adjusting mask for scene focus mode...")
|
| 453 |
+
|
| 454 |
+
mask_array = np.array(mask)
|
| 455 |
+
img_array = np.array(original_image.convert('RGB'))
|
| 456 |
+
|
| 457 |
+
# Expand mask to include nearby objects
|
| 458 |
+
# Use larger dilation kernel to include furniture/objects
|
| 459 |
+
kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
|
| 460 |
+
expanded_mask = cv2.dilate(mask_array, kernel_large, iterations=2)
|
| 461 |
+
|
| 462 |
+
# Find contours in the expanded area to detect objects
|
| 463 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 464 |
+
edges = cv2.Canny(gray, 30, 100)
|
| 465 |
+
|
| 466 |
+
# Apply edge detection only in the expanded region
|
| 467 |
+
expanded_region = (expanded_mask > 0) & (mask_array == 0)
|
| 468 |
+
object_edges = np.zeros_like(edges)
|
| 469 |
+
object_edges[expanded_region] = edges[expanded_region]
|
| 470 |
+
|
| 471 |
+
# Close gaps to form complete objects
|
| 472 |
+
kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 473 |
+
object_mask = cv2.morphologyEx(object_edges, cv2.MORPH_CLOSE, kernel_close)
|
| 474 |
+
object_mask = cv2.dilate(object_mask, kernel_close, iterations=1)
|
| 475 |
+
|
| 476 |
+
# Combine with original mask
|
| 477 |
+
final_mask = np.maximum(mask_array, object_mask)
|
| 478 |
+
|
| 479 |
+
logger.info("✅ Scene focus adjustment completed")
|
| 480 |
+
return Image.fromarray(final_mask)
|
| 481 |
+
|
| 482 |
+
except Exception as e:
|
| 483 |
+
logger.error(f"❌ Scene focus adjustment failed: {e}")
|
| 484 |
+
return mask
|
| 485 |
+
|
| 486 |
+
def create_gradient_based_mask(self, original_image: Image.Image, mode: str = "center", focus_mode: str = "person") -> Image.Image:
|
| 487 |
+
"""
|
| 488 |
+
Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods
|
| 489 |
+
Focus mode: 'person' for tight crop around person, 'scene' for including nearby objects
|
| 490 |
+
"""
|
| 491 |
+
width, height = original_image.size
|
| 492 |
+
logger.info(f"🎯 Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}")
|
| 493 |
+
|
| 494 |
+
if mode == "center":
|
| 495 |
+
# Try using deep learning models for intelligent foreground extraction
|
| 496 |
+
logger.info("🤖 Attempting deep learning mask generation...")
|
| 497 |
+
dl_mask = self.try_deep_learning_mask(original_image)
|
| 498 |
+
if dl_mask is not None:
|
| 499 |
+
logger.info("✅ Using deep learning generated mask")
|
| 500 |
+
# Apply focus mode adjustments to deep learning mask
|
| 501 |
+
if focus_mode == "scene":
|
| 502 |
+
dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
|
| 503 |
+
return dl_mask
|
| 504 |
+
|
| 505 |
+
# Fallback to traditional method
|
| 506 |
+
logger.info("🔄 Deep learning failed, using traditional gradient-based method")
|
| 507 |
+
img_array = np.array(original_image.convert('RGB'))
|
| 508 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 509 |
+
|
| 510 |
+
# First-order derivatives: use Sobel operator for edge detection
|
| 511 |
+
grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
|
| 512 |
+
grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 513 |
+
gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)
|
| 514 |
+
|
| 515 |
+
# Second-order derivatives: use Laplacian operator for texture change detection
|
| 516 |
+
laplacian = cv2.Laplacian(gray, cv2.CV_64F, ksize=3)
|
| 517 |
+
laplacian_abs = np.abs(laplacian)
|
| 518 |
+
|
| 519 |
+
# Combine first and second order derivatives
|
| 520 |
+
combined_edges = gradient_magnitude * 0.7 + laplacian_abs * 0.3
|
| 521 |
+
combined_edges = (combined_edges / np.max(combined_edges)) * 255
|
| 522 |
+
|
| 523 |
+
# Threshold processing to find strong edges
|
| 524 |
+
_, edge_binary = cv2.threshold(combined_edges.astype(np.uint8), 20, 255, cv2.THRESH_BINARY)
|
| 525 |
+
|
| 526 |
+
# Morphological operations to connect edges
|
| 527 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
| 528 |
+
edge_binary = cv2.morphologyEx(edge_binary, cv2.MORPH_CLOSE, kernel)
|
| 529 |
+
|
| 530 |
+
# Find contours and create mask
|
| 531 |
+
contours, _ = cv2.findContours(edge_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 532 |
+
|
| 533 |
+
if contours:
|
| 534 |
+
# Find largest contour (main subject)
|
| 535 |
+
largest_contour = max(contours, key=cv2.contourArea)
|
| 536 |
+
contour_mask = np.zeros((height, width), dtype=np.uint8)
|
| 537 |
+
cv2.fillPoly(contour_mask, [largest_contour], 255)
|
| 538 |
+
|
| 539 |
+
# Create foreground enhancement mask: specially protect dark regions
|
| 540 |
+
dark_mask = (gray < 90).astype(np.uint8) * 255
|
| 541 |
+
morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 542 |
+
dark_mask = cv2.morphologyEx(dark_mask, cv2.MORPH_CLOSE, morph_kernel, iterations=1)
|
| 543 |
+
dark_mask = cv2.dilate(dark_mask, morph_kernel, iterations=2)
|
| 544 |
+
contour_mask = cv2.bitwise_or(contour_mask, dark_mask)
|
| 545 |
+
|
| 546 |
+
# Get core foreground: clean holes and fill gaps
|
| 547 |
+
close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 548 |
+
core_mask = cv2.morphologyEx(contour_mask, cv2.MORPH_CLOSE, close_kernel, iterations=1)
|
| 549 |
+
|
| 550 |
+
open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 551 |
+
core_mask = cv2.morphologyEx(core_mask, cv2.MORPH_OPEN, open_kernel, iterations=1)
|
| 552 |
+
|
| 553 |
+
# Convert to binary core (0/255)
|
| 554 |
+
_, core_binary = cv2.threshold(core_mask, 127, 255, cv2.THRESH_BINARY)
|
| 555 |
+
|
| 556 |
+
# Keep only slight dilation to avoid foreground being eaten
|
| 557 |
+
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 558 |
+
core_binary = cv2.dilate(core_binary, dilate_kernel, iterations=1)
|
| 559 |
+
|
| 560 |
+
# Distance transform feathering: shrink feathering range for sharp edges
|
| 561 |
+
FEATHER_PX = 4
|
| 562 |
+
|
| 563 |
+
# Calculate distance transform
|
| 564 |
+
core_float = core_binary.astype(np.float32) / 255.0
|
| 565 |
+
distances = cv2.distanceTransform((1 - core_float).astype(np.uint8), cv2.DIST_L2, 5)
|
| 566 |
+
|
| 567 |
+
# Create feathering mask: 0→FEATHER_PX linear mapping to 1→0
|
| 568 |
+
feather_mask = np.ones_like(distances)
|
| 569 |
+
edge_region = (distances > 0) & (distances <= FEATHER_PX)
|
| 570 |
+
feather_mask[edge_region] = 1.0 - (distances[edge_region] / FEATHER_PX)
|
| 571 |
+
feather_mask[distances > FEATHER_PX] = 0.0
|
| 572 |
+
|
| 573 |
+
# Apply double-smoothstep curve: make transition steeper, reduce semi-transparent halos
|
| 574 |
+
def double_smoothstep(t):
|
| 575 |
+
t = np.clip(t, 0, 1)
|
| 576 |
+
s1 = t * t * (3 - 2 * t)
|
| 577 |
+
return s1 * s1 * (3 - 2 * s1)  # smoothstep applied twice: steeper than a single pass (note: not identical to the quintic smootherstep 6t^5 - 15t^4 + 10t^3)
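For intuition, applying smoothstep twice pushes mid-range alpha values toward 0 or 1 noticeably harder than a single pass, which is what shortens the visible halo. A quick numeric comparison:

```python
import numpy as np

def smoothstep(t):
    t = np.clip(t, 0, 1)
    return t * t * (3 - 2 * t)

t = np.array([0.25, 0.5, 0.75])
print(smoothstep(t))              # [0.156 0.5   0.844]
print(smoothstep(smoothstep(t)))  # [0.066 0.5   0.934]
```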
|
| 578 |
+
|
| 579 |
+
# Combine core with feathering: core area keeps 255, edges use double_smoothstep feathering
|
| 580 |
+
final_alpha = np.zeros_like(distances)
|
| 581 |
+
final_alpha[core_binary > 127] = 1.0 # Core area
|
| 582 |
+
final_alpha[edge_region] = double_smoothstep(feather_mask[edge_region]) # Feathering area
|
| 583 |
+
|
| 584 |
+
# Convert to 0-255 range
|
| 585 |
+
final_mask = (final_alpha * 255).astype(np.uint8)
|
| 586 |
+
|
| 587 |
+
# Apply guided filter for edge-preserving smoothing
|
| 588 |
+
final_mask = self.apply_guided_filter(final_mask, original_image, radius=8, eps=0.01)
|
| 589 |
+
|
| 590 |
+
mask = Image.fromarray(final_mask)
|
| 591 |
+
else:
|
| 592 |
+
# Backup plan: use large ellipse
|
| 593 |
+
mask = Image.new('L', (width, height), 0)
|
| 594 |
+
draw = ImageDraw.Draw(mask)
|
| 595 |
+
center_x, center_y = width // 2, height // 2
|
| 596 |
+
width_radius = int(width * 0.45)
|
| 597 |
+
height_radius = int(height * 0.48)
|
| 598 |
+
draw.ellipse([
|
| 599 |
+
center_x - width_radius, center_y - height_radius,
|
| 600 |
+
center_x + width_radius, center_y + height_radius
|
| 601 |
+
], fill=255)
|
| 602 |
+
# Apply guided filter instead of Gaussian blur
|
| 603 |
+
mask_array = np.array(mask)
|
| 604 |
+
mask_array = self.apply_guided_filter(mask_array, original_image, radius=10, eps=0.02)
|
| 605 |
+
mask = Image.fromarray(mask_array)
|
| 606 |
+
|
| 607 |
+
elif mode == "left_half":
|
| 608 |
+
# Keep original logic unchanged so existing cases (e.g. the Snoopy cartoon demo) continue to work
|
| 609 |
+
mask = Image.new('L', (width, height), 0)
|
| 610 |
+
mask_array = np.array(mask)
|
| 611 |
+
mask_array[:, :width//2] = 255
|
| 612 |
+
|
| 613 |
+
transition_zone = width // 10
|
| 614 |
+
for i in range(transition_zone):
|
| 615 |
+
x_pos = width//2 + i
|
| 616 |
+
if x_pos < width:
|
| 617 |
+
alpha = 255 * (1 - i / transition_zone)
|
| 618 |
+
mask_array[:, x_pos] = int(alpha)
|
| 619 |
+
|
| 620 |
+
mask = Image.fromarray(mask_array)
|
| 621 |
+
|
| 622 |
+
elif mode == "right_half":
|
| 623 |
+
# Keep original logic unchanged so existing cases (e.g. the Snoopy cartoon demo) continue to work
|
| 624 |
+
mask = Image.new('L', (width, height), 0)
|
| 625 |
+
mask_array = np.array(mask)
|
| 626 |
+
mask_array[:, width//2:] = 255
|
| 627 |
+
|
| 628 |
+
transition_zone = width // 10
|
| 629 |
+
for i in range(transition_zone):
|
| 630 |
+
x_pos = width//2 - i - 1
|
| 631 |
+
if x_pos >= 0:
|
| 632 |
+
alpha = 255 * (1 - i / transition_zone)
|
| 633 |
+
mask_array[:, x_pos] = int(alpha)
|
| 634 |
+
|
| 635 |
+
mask = Image.fromarray(mask_array)
|
| 636 |
+
|
| 637 |
+
elif mode == "full":
|
| 638 |
+
mask = Image.new('L', (width, height), 0)
|
| 639 |
+
draw = ImageDraw.Draw(mask)
|
| 640 |
+
center_x, center_y = width // 2, height // 2
|
| 641 |
+
radius = min(width, height) // 8
|
| 642 |
+
|
| 643 |
+
draw.ellipse([
|
| 644 |
+
center_x - radius, center_y - radius,
|
| 645 |
+
center_x + radius, center_y + radius
|
| 646 |
+
], fill=255)
|
| 647 |
+
|
| 648 |
+
mask = mask.filter(ImageFilter.GaussianBlur(radius=5))
|
| 649 |
+
|
| 650 |
+
return mask
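A possible end-to-end call into this module (the file paths are assumptions):

```python
from PIL import Image
from mask_generator import MaskGenerator

generator = MaskGenerator(max_image_size=1024, device="auto")
image = Image.open("input.png")
mask = generator.create_gradient_based_mask(image, mode="center", focus_mode="person")
mask.save("foreground_mask.png")
```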
|
model_manager.py
ADDED
|
@@ -0,0 +1,293 @@
| 1 |
+
import logging
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
from typing import Dict, Any, Optional, Callable
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from threading import Lock
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
logger.setLevel(logging.INFO)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class ModelInfo:
|
| 15 |
+
"""Information about a registered model."""
|
| 16 |
+
name: str
|
| 17 |
+
loader: Callable[[], Any]
|
| 18 |
+
is_critical: bool = False # Critical models are not unloaded under memory pressure
|
| 19 |
+
estimated_memory_gb: float = 0.0
|
| 20 |
+
is_loaded: bool = False
|
| 21 |
+
last_used: float = 0.0
|
| 22 |
+
model_instance: Any = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ModelManager:
|
| 26 |
+
"""
|
| 27 |
+
Singleton model manager for unified model lifecycle management.
|
| 28 |
+
Handles lazy loading, caching, and intelligent memory management.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
_instance = None
|
| 32 |
+
_lock = Lock()
|
| 33 |
+
|
| 34 |
+
def __new__(cls):
|
| 35 |
+
if cls._instance is None:
|
| 36 |
+
with cls._lock:
|
| 37 |
+
if cls._instance is None:
|
| 38 |
+
cls._instance = super().__new__(cls)
|
| 39 |
+
cls._instance._initialized = False
|
| 40 |
+
return cls._instance
|
| 41 |
+
|
| 42 |
+
def __init__(self):
|
| 43 |
+
if self._initialized:
|
| 44 |
+
return
|
| 45 |
+
|
| 46 |
+
self._models: Dict[str, ModelInfo] = {}
|
| 47 |
+
self._memory_threshold = 0.80 # Trigger cleanup at 80% GPU memory usage
|
| 48 |
+
self._device = self._detect_device()
|
| 49 |
+
|
| 50 |
+
logger.info(f"🔧 ModelManager initialized on {self._device}")
|
| 51 |
+
self._initialized = True
|
| 52 |
+
|
| 53 |
+
def _detect_device(self) -> str:
|
| 54 |
+
"""Detect best available device."""
|
| 55 |
+
if torch.cuda.is_available():
|
| 56 |
+
return "cuda"
|
| 57 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 58 |
+
return "mps"
|
| 59 |
+
return "cpu"
|
| 60 |
+
|
| 61 |
+
def register_model(
|
| 62 |
+
self,
|
| 63 |
+
name: str,
|
| 64 |
+
loader: Callable[[], Any],
|
| 65 |
+
is_critical: bool = False,
|
| 66 |
+
estimated_memory_gb: float = 0.0
|
| 67 |
+
):
|
| 68 |
+
"""
|
| 69 |
+
Register a model for managed loading.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
name: Unique model identifier
|
| 73 |
+
loader: Callable that returns the loaded model
|
| 74 |
+
is_critical: If True, model won't be unloaded under memory pressure
|
| 75 |
+
estimated_memory_gb: Estimated GPU memory usage in GB
|
| 76 |
+
"""
|
| 77 |
+
if name in self._models:
|
| 78 |
+
logger.warning(f"⚠️ Model '{name}' already registered, updating")
|
| 79 |
+
|
| 80 |
+
self._models[name] = ModelInfo(
|
| 81 |
+
name=name,
|
| 82 |
+
loader=loader,
|
| 83 |
+
is_critical=is_critical,
|
| 84 |
+
estimated_memory_gb=estimated_memory_gb,
|
| 85 |
+
is_loaded=False,
|
| 86 |
+
last_used=0.0,
|
| 87 |
+
model_instance=None
|
| 88 |
+
)
|
| 89 |
+
logger.info(f"📝 Registered model: {name} (critical={is_critical}, ~{estimated_memory_gb:.1f}GB)")
|
| 90 |
+
|
| 91 |
+
def load_model(self, name: str) -> Any:
|
| 92 |
+
"""
|
| 93 |
+
Load a model by name. Returns cached instance if already loaded.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
name: Model identifier
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Loaded model instance
|
| 100 |
+
|
| 101 |
+
Raises:
|
| 102 |
+
KeyError: If model not registered
|
| 103 |
+
RuntimeError: If loading fails
|
| 104 |
+
"""
|
| 105 |
+
if name not in self._models:
|
| 106 |
+
raise KeyError(f"Model '{name}' not registered")
|
| 107 |
+
|
| 108 |
+
model_info = self._models[name]
|
| 109 |
+
|
| 110 |
+
# Return cached instance
|
| 111 |
+
if model_info.is_loaded and model_info.model_instance is not None:
|
| 112 |
+
model_info.last_used = time.time()
|
| 113 |
+
logger.debug(f"📦 Using cached model: {name}")
|
| 114 |
+
return model_info.model_instance
|
| 115 |
+
|
| 116 |
+
# Check memory pressure before loading
|
| 117 |
+
self.check_memory_pressure()
|
| 118 |
+
|
| 119 |
+
# Load the model
|
| 120 |
+
try:
|
| 121 |
+
logger.info(f"📥 Loading model: {name}")
|
| 122 |
+
start_time = time.time()
|
| 123 |
+
|
| 124 |
+
model_instance = model_info.loader()
|
| 125 |
+
|
| 126 |
+
model_info.model_instance = model_instance
|
| 127 |
+
model_info.is_loaded = True
|
| 128 |
+
model_info.last_used = time.time()
|
| 129 |
+
|
| 130 |
+
load_time = time.time() - start_time
|
| 131 |
+
logger.info(f"✅ Model '{name}' loaded in {load_time:.1f}s")
|
| 132 |
+
|
| 133 |
+
return model_instance
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"❌ Failed to load model '{name}': {e}")
|
| 137 |
+
raise RuntimeError(f"Model loading failed: {e}") from e
|
| 138 |
+
|
| 139 |
+
def unload_model(self, name: str):
|
| 140 |
+
"""
|
| 141 |
+
Unload a specific model to free memory.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
name: Model identifier
|
| 145 |
+
"""
|
| 146 |
+
if name not in self._models:
|
| 147 |
+
return
|
| 148 |
+
|
| 149 |
+
model_info = self._models[name]
|
| 150 |
+
|
| 151 |
+
if not model_info.is_loaded:
|
| 152 |
+
return
|
| 153 |
+
|
| 154 |
+
try:
|
| 155 |
+
logger.info(f"🗑️ Unloading model: {name}")
|
| 156 |
+
|
| 157 |
+
# Delete model instance
|
| 158 |
+
if model_info.model_instance is not None:
|
| 159 |
+
del model_info.model_instance
|
| 160 |
+
|
| 161 |
+
model_info.model_instance = None
|
| 162 |
+
model_info.is_loaded = False
|
| 163 |
+
|
| 164 |
+
# Cleanup
|
| 165 |
+
gc.collect()
|
| 166 |
+
if torch.cuda.is_available():
|
| 167 |
+
torch.cuda.empty_cache()
|
| 168 |
+
|
| 169 |
+
logger.info(f"✅ Model '{name}' unloaded")
|
| 170 |
+
|
| 171 |
+
except Exception as e:
|
| 172 |
+
logger.error(f"❌ Error unloading model '{name}': {e}")
|
| 173 |
+
|
| 174 |
+
def check_memory_pressure(self) -> bool:
|
| 175 |
+
"""
|
| 176 |
+
Check GPU memory usage and unload least-used non-critical models if needed.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
True if cleanup was performed
|
| 180 |
+
"""
        if not torch.cuda.is_available():
            return False

        allocated = torch.cuda.memory_allocated() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        usage_ratio = allocated / total

        if usage_ratio < self._memory_threshold:
            return False

        logger.warning(f"⚠️ Memory pressure detected: {usage_ratio:.1%} used")

        # Find non-critical models sorted by last used time
        unloadable = [
            (name, info) for name, info in self._models.items()
            if info.is_loaded and not info.is_critical
        ]
        unloadable.sort(key=lambda x: x[1].last_used)

        # Unload oldest non-critical models
        cleaned = False
        for name, info in unloadable:
            self.unload_model(name)
            cleaned = True

            # Re-check memory
            new_ratio = torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory
            if new_ratio < self._memory_threshold * 0.7:  # Target 70% of threshold
                break

        return cleaned

    def force_cleanup(self):
        """Force cleanup all non-critical models and clear caches."""
        logger.info("🧹 Force cleanup initiated")

        # Unload all non-critical models
        for name, info in self._models.items():
            if info.is_loaded and not info.is_critical:
                self.unload_model(name)

        # Aggressive garbage collection
        for _ in range(5):
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

        logger.info("✅ Force cleanup completed")

    def get_memory_status(self) -> Dict[str, Any]:
        """
        Get detailed memory status.

        Returns:
            Dictionary with memory statistics
        """
        status = {
            "device": self._device,
            "models": {},
            "total_estimated_gb": 0.0
        }

        # Model status
        for name, info in self._models.items():
            status["models"][name] = {
                "loaded": info.is_loaded,
                "critical": info.is_critical,
                "estimated_gb": info.estimated_memory_gb,
                "last_used": info.last_used
            }
            if info.is_loaded:
                status["total_estimated_gb"] += info.estimated_memory_gb

        # GPU memory
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3

            status["gpu"] = {
                "allocated_gb": round(allocated, 2),
                "total_gb": round(total, 2),
                "cached_gb": round(cached, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1)
            }

        return status

    def get_loaded_models(self) -> list:
        """Get list of currently loaded model names."""
        return [name for name, info in self._models.items() if info.is_loaded]

    def is_model_loaded(self, name: str) -> bool:
        """Check if a specific model is loaded."""
        if name not in self._models:
            return False
        return self._models[name].is_loaded


# Global singleton instance
_model_manager = None


def get_model_manager() -> ModelManager:
    """Get the global ModelManager singleton instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
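A minimal usage sketch of the singleton accessor above (not part of model_manager.py): it exercises only the methods visible in this tail of the file, and the model name passed to is_model_loaded is a hypothetical example.

# usage_sketch_model_manager.py -- illustrative only, names marked as assumptions
from model_manager import get_model_manager

manager = get_model_manager()            # returns the same instance on every call
status = manager.get_memory_status()     # device, per-model info, optional "gpu" block
print(status["device"], status["total_estimated_gb"])

if not manager.is_model_loaded("sdxl_base"):   # "sdxl_base" is a hypothetical model name
    print("Currently loaded:", manager.get_loaded_models())

manager.force_cleanup()                  # unloads non-critical models and clears CUDA caches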
quality_checker.py
ADDED
import logging
import numpy as np
import cv2
from PIL import Image
from typing import Dict, Any, Tuple, Optional
from dataclasses import dataclass

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class QualityResult:
    """Result of a quality check."""
    score: float  # 0-100
    passed: bool
    issue: str
    details: Dict[str, Any]


class QualityChecker:
    """
    Automated quality validation system for generated images.
    Provides checks for mask coverage, edge continuity, and color harmony.
    """

    # Quality thresholds
    THRESHOLD_PASS = 70
    THRESHOLD_WARNING = 50

    def __init__(self, strictness: str = "standard"):
        """
        Initialize QualityChecker.

        Args:
            strictness: Quality check strictness level
                "lenient" - Only check fatal issues
                "standard" - All checks with moderate thresholds
                "strict" - High standards required
        """
        self.strictness = strictness
        self._set_thresholds()

    def _set_thresholds(self):
        """Set quality thresholds based on strictness level."""
        if self.strictness == "lenient":
            self.min_coverage = 0.03  # 3%
            self.min_edge_score = 40
            self.min_harmony_score = 40
        elif self.strictness == "strict":
            self.min_coverage = 0.10  # 10%
            self.min_edge_score = 75
            self.min_harmony_score = 75
        else:  # standard
            self.min_coverage = 0.05  # 5%
            self.min_edge_score = 60
            self.min_harmony_score = 60

    def check_mask_coverage(self, mask: Image.Image) -> QualityResult:
        """
        Verify mask coverage is adequate.

        Args:
            mask: Grayscale mask image (L mode)

        Returns:
            QualityResult with coverage analysis
        """
        try:
            mask_array = np.array(mask.convert('L'))
            height, width = mask_array.shape
            total_pixels = height * width

            # Count foreground pixels
            fg_pixels = np.count_nonzero(mask_array > 127)
            coverage_ratio = fg_pixels / total_pixels

            # Check for isolated small regions (noise)
            _, binary = cv2.threshold(mask_array, 127, 255, cv2.THRESH_BINARY)
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

            # Count significant regions (> 1% of image)
            min_region_size = total_pixels * 0.01
            significant_regions = sum(1 for i in range(1, num_labels)
                                      if stats[i, cv2.CC_STAT_AREA] > min_region_size)

            # Calculate fragmentation (many small regions = bad)
            fragmentation_penalty = max(0, (num_labels - 1 - significant_regions) * 2)

            # Score calculation
            coverage_score = min(100, coverage_ratio * 200)  # 50% coverage = 100 score
            final_score = max(0, coverage_score - fragmentation_penalty)

            # Determine pass/fail
            passed = coverage_ratio >= self.min_coverage and significant_regions >= 1
            issue = ""

            if coverage_ratio < self.min_coverage:
                issue = f"Low foreground coverage ({coverage_ratio:.1%})"
            elif significant_regions == 0:
                issue = "No significant foreground regions detected"
            elif fragmentation_penalty > 20:
                issue = f"Fragmented mask with {num_labels - 1} isolated regions"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "coverage_ratio": coverage_ratio,
                    "foreground_pixels": fg_pixels,
                    "total_regions": num_labels - 1,
                    "significant_regions": significant_regions
                }
            )

        except Exception as e:
            logger.error(f"❌ Mask coverage check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def check_edge_continuity(self, mask: Image.Image) -> QualityResult:
        """
        Check if mask edges are continuous and smooth.

        Args:
            mask: Grayscale mask image

        Returns:
            QualityResult with edge analysis
        """
        try:
            mask_array = np.array(mask.convert('L'))

            # Find edges using morphological gradient
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            gradient = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, kernel)

            # Get edge pixels
            edge_pixels = gradient > 20
            edge_count = np.count_nonzero(edge_pixels)

            if edge_count == 0:
                return QualityResult(
                    score=50,
                    passed=False,
                    issue="No edges detected in mask",
                    details={"edge_count": 0}
                )

            # Check edge smoothness using Laplacian
            laplacian = cv2.Laplacian(mask_array, cv2.CV_64F)
            edge_laplacian = np.abs(laplacian[edge_pixels])

            # High Laplacian values indicate jagged edges
            smoothness = 100 - min(100, np.std(edge_laplacian) * 0.5)

            # Check for gaps in edges
            # Dilate and erode to find disconnections
            dilated = cv2.dilate(gradient, kernel, iterations=1)
            eroded = cv2.erode(dilated, kernel, iterations=1)
            gaps = cv2.subtract(dilated, eroded)
            gap_ratio = np.count_nonzero(gaps) / max(edge_count, 1)

            # Calculate final score
            gap_penalty = min(40, gap_ratio * 100)
            final_score = max(0, smoothness - gap_penalty)

            passed = final_score >= self.min_edge_score
            issue = ""

            if final_score < self.min_edge_score:
                if smoothness < 60:
                    issue = "Jagged or rough edges detected"
                elif gap_ratio > 0.3:
                    issue = "Discontinuous edges with gaps"
                else:
                    issue = "Poor edge quality"

            return QualityResult(
                score=final_score,
                passed=passed,
                issue=issue,
                details={
                    "edge_count": edge_count,
                    "smoothness": smoothness,
                    "gap_ratio": gap_ratio
                }
            )

        except Exception as e:
            logger.error(f"❌ Edge continuity check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def check_color_harmony(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image
    ) -> QualityResult:
        """
        Evaluate color harmony between foreground and background.

        Args:
            foreground: Original foreground image
            background: Generated background image
            mask: Combination mask

        Returns:
            QualityResult with harmony analysis
        """
        try:
            fg_array = np.array(foreground.convert('RGB'))
            bg_array = np.array(background.convert('RGB'))
            mask_array = np.array(mask.convert('L'))

            # Get foreground and background regions
            fg_region = mask_array > 127
            bg_region = mask_array <= 127

            if not np.any(fg_region) or not np.any(bg_region):
                return QualityResult(
                    score=50,
                    passed=True,
                    issue="Cannot analyze harmony - insufficient regions",
                    details={}
                )

            # Convert to LAB for perceptual analysis
            fg_lab = cv2.cvtColor(fg_array, cv2.COLOR_RGB2LAB).astype(np.float32)
            bg_lab = cv2.cvtColor(bg_array, cv2.COLOR_RGB2LAB).astype(np.float32)

            # Calculate average colors
            fg_avg_l = np.mean(fg_lab[fg_region, 0])
            fg_avg_a = np.mean(fg_lab[fg_region, 1])
            fg_avg_b = np.mean(fg_lab[fg_region, 2])

            bg_avg_l = np.mean(bg_lab[bg_region, 0])
            bg_avg_a = np.mean(bg_lab[bg_region, 1])
            bg_avg_b = np.mean(bg_lab[bg_region, 2])

            # Calculate color differences
            delta_l = abs(fg_avg_l - bg_avg_l)
            delta_a = abs(fg_avg_a - bg_avg_a)
            delta_b = abs(fg_avg_b - bg_avg_b)

            # Overall color difference (Delta E approximation)
            delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)

            # Score calculation
            # Moderate difference is good (20-60 Delta E)
            # Too similar or too different is problematic
            if delta_e < 10:
                harmony_score = 60  # Too similar, foreground may get lost
                issue = "Foreground and background colors too similar"
            elif delta_e > 80:
                harmony_score = 50  # Too different, may look unnatural
                issue = "High color contrast may look unnatural"
            elif 20 <= delta_e <= 60:
                harmony_score = 100  # Ideal range
                issue = ""
            else:
                harmony_score = 80
                issue = ""

            # Check for extreme contrast (very dark fg on very bright bg or vice versa)
            brightness_contrast = abs(fg_avg_l - bg_avg_l)
            if brightness_contrast > 100:
                harmony_score = max(40, harmony_score - 30)
                issue = "Extreme brightness contrast between foreground and background"

            passed = harmony_score >= self.min_harmony_score

            return QualityResult(
                score=harmony_score,
                passed=passed,
                issue=issue,
                details={
                    "delta_e": delta_e,
                    "delta_l": delta_l,
                    "delta_a": delta_a,
                    "delta_b": delta_b,
                    "fg_luminance": fg_avg_l,
                    "bg_luminance": bg_avg_l
                }
            )

        except Exception as e:
            logger.error(f"❌ Color harmony check failed: {e}")
            return QualityResult(score=0, passed=False, issue=str(e), details={})

    def run_all_checks(
        self,
        foreground: Image.Image,
        background: Image.Image,
        mask: Image.Image,
        combined: Optional[Image.Image] = None
    ) -> Dict[str, Any]:
        """
        Run all quality checks and return comprehensive results.

        Args:
            foreground: Original foreground image
            background: Generated background
            mask: Combination mask
            combined: Final combined image (optional)

        Returns:
            Dictionary with all check results and overall score
        """
        logger.info("🔍 Running quality checks...")

        results = {
            "checks": {},
            "overall_score": 0,
            "passed": True,
            "warnings": [],
            "errors": []
        }

        # Run individual checks
        coverage_result = self.check_mask_coverage(mask)
        results["checks"]["mask_coverage"] = {
            "score": coverage_result.score,
            "passed": coverage_result.passed,
            "issue": coverage_result.issue,
            "details": coverage_result.details
        }

        edge_result = self.check_edge_continuity(mask)
        results["checks"]["edge_continuity"] = {
            "score": edge_result.score,
            "passed": edge_result.passed,
            "issue": edge_result.issue,
            "details": edge_result.details
        }

        harmony_result = self.check_color_harmony(foreground, background, mask)
        results["checks"]["color_harmony"] = {
            "score": harmony_result.score,
            "passed": harmony_result.passed,
            "issue": harmony_result.issue,
            "details": harmony_result.details
        }

        # Calculate overall score (weighted average)
        weights = {
            "mask_coverage": 0.4,
            "edge_continuity": 0.3,
            "color_harmony": 0.3
        }

        total_score = (
            coverage_result.score * weights["mask_coverage"] +
            edge_result.score * weights["edge_continuity"] +
            harmony_result.score * weights["color_harmony"]
        )
        results["overall_score"] = round(total_score, 1)

        # Determine overall pass/fail
        results["passed"] = all([
            coverage_result.passed,
            edge_result.passed,
            harmony_result.passed
        ])

        # Collect warnings and errors
        for check_name, check_data in results["checks"].items():
            if check_data["issue"]:
                if check_data["passed"]:
                    results["warnings"].append(f"{check_name}: {check_data['issue']}")
                else:
                    results["errors"].append(f"{check_name}: {check_data['issue']}")

        logger.info(f"📊 Quality check complete - Score: {results['overall_score']}, Passed: {results['passed']}")

        return results

    def get_quality_summary(self, results: Dict[str, Any]) -> str:
        """
        Generate human-readable quality summary.

        Args:
            results: Results from run_all_checks

        Returns:
            Summary string
        """
        score = results["overall_score"]
        passed = results["passed"]

        if score >= 90:
            grade = "Excellent"
        elif score >= 75:
            grade = "Good"
        elif score >= 60:
            grade = "Acceptable"
        elif score >= 40:
            grade = "Needs Improvement"
        else:
            grade = "Poor"

        summary = f"Quality: {grade} ({score:.0f}/100)"

        if results["errors"]:
            summary += f"\nIssues: {'; '.join(results['errors'])}"
        elif results["warnings"]:
            summary += f"\nNotes: {'; '.join(results['warnings'])}"

        return summary
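A note and a usage sketch, placed here for orientation rather than as part of quality_checker.py: the "Delta E approximation" above is the CIE76 distance sqrt(delta_L^2 + delta_a^2 + delta_b^2) computed on OpenCV's 8-bit LAB encoding, where all three channels are mapped into 0-255, so the 10/20/60/80 breakpoints live in that scaled space rather than in true CIELAB units. The sketch below assumes three PIL images are available; the file names are placeholders.

# usage_sketch_quality_checker.py -- illustrative only
from PIL import Image
from quality_checker import QualityChecker

fg = Image.open("foreground.png")          # placeholder paths
bg = Image.open("background.png")
mask = Image.open("mask.png").convert("L")

checker = QualityChecker(strictness="standard")
results = checker.run_all_checks(fg, bg, mask)

print(results["overall_score"], results["passed"])
print(checker.get_quality_summary(results))   # e.g. "Quality: Good (78/100)" plus notes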
requirements.txt
ADDED
# SceneWeaver Hugging Face Spaces Deployment Requirements
# Optimized for ZeroGPU environment with safe version ranges

# ============================================
# Core Deep Learning Framework
# ============================================
# PyTorch 2.x series - compatible with SDXL and xformers
torch>=2.0.0,<2.5.0
torchvision>=0.15.0,<0.20.0
torchaudio>=2.0.0,<2.5.0

# ============================================
# Diffusion Models and Transformers
# ============================================
# Diffusers 0.25+ has better SDXL support, <0.32 for stability
diffusers>=0.25.0,<0.32.0
# Transformers compatible with diffusers and open_clip
transformers>=4.35.0,<4.46.0
# Accelerate for model loading optimizations
accelerate>=0.24.0,<0.35.0
# xformers for memory efficient attention (optional, may fail on some systems)
# xformers>=0.0.22,<0.0.29

# ============================================
# Computer Vision and Image Processing
# ============================================
# OpenCV for image processing
opencv-python>=4.8.0,<4.11.0
# opencv-contrib-python for guided filter (cv2.ximgproc)
opencv-contrib-python>=4.8.0,<4.11.0
# Pillow for image I/O
Pillow>=9.5.0,<11.0.0
# SciPy for scientific computing
scipy>=1.10.0,<1.15.0

# ============================================
# Background Removal
# ============================================
# rembg for foreground extraction (CPU version for compatibility)
rembg>=2.0.50,<2.1.0

# ============================================
# Multi-modal Understanding (CLIP)
# ============================================
# OpenCLIP for image analysis
open_clip_torch>=2.20.0,<2.27.0
# Sentence transformers (dependency)
sentence-transformers>=2.2.0,<3.1.0

# ============================================
# Web Interface
# ============================================
# Gradio 4.x for modern UI
gradio>=4.0.0,<5.0.0

# ============================================
# Core Scientific Computing
# ============================================
# NumPy 1.x for compatibility
numpy>=1.24.0,<2.0.0

# ============================================
# Hugging Face Integration
# ============================================
# HuggingFace Hub for model downloads
huggingface_hub>=0.19.0,<0.27.0
# Safetensors for efficient model loading
safetensors>=0.4.0,<0.5.0

# ============================================
# System Utilities
# ============================================
# psutil for memory monitoring
psutil>=5.9.0,<6.1.0
# requests for HTTP operations
requests>=2.28.0,<2.33.0

# ============================================
# Hugging Face Spaces (auto-installed on Spaces)
# ============================================
# spaces  # ZeroGPU support - auto-available on HF Spaces
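For local testing outside Spaces, installing these pins with `pip install -r requirements.txt` into a fresh virtual environment (a recent Python such as 3.10+ is a reasonable assumption) mirrors the Spaces build; on Hugging Face Spaces the file is installed automatically, and the commented-out `spaces` package is provided by the platform for ZeroGPU support.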
scene_templates.py
ADDED
import logging
from typing import Dict, List, Optional
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class SceneTemplate:
    """Data class representing a scene template"""
    key: str
    name: str
    prompt: str
    negative_extra: str
    category: str
    icon: str
    guidance_scale: float = 7.5


class SceneTemplateManager:
    """
    Manages curated scene templates for background generation.
    Provides categorized presets that users can select with one click.
    """

    # Scene template definitions
    TEMPLATES: Dict[str, SceneTemplate] = {
        # === Professional Category ===
        "office_modern": SceneTemplate(
            key="office_modern",
            name="Modern Office",
            prompt="modern minimalist office interior, clean white desk, large floor-to-ceiling windows, natural daylight, professional corporate environment, soft shadows, contemporary furniture",
            negative_extra="messy, cluttered, dark, old",
            category="Professional",
            icon="🏢",
            guidance_scale=7.5
        ),
        "office_executive": SceneTemplate(
            key="office_executive",
            name="Executive Suite",
            prompt="luxurious executive office, mahogany desk, leather chair, city skyline view through windows, warm ambient lighting, bookshelf, elegant professional setting",
            negative_extra="cheap, cramped, messy",
            category="Professional",
            icon="👔",
            guidance_scale=7.5
        ),
        "studio_white": SceneTemplate(
            key="studio_white",
            name="White Studio",
            prompt="clean white photography studio background, professional lighting setup, seamless white backdrop, soft diffused light, minimal shadows",
            negative_extra="colored, textured, dirty",
            category="Professional",
            icon="📷",
            guidance_scale=8.0
        ),
        "coworking": SceneTemplate(
            key="coworking",
            name="Coworking Space",
            prompt="modern coworking space, open plan office, plants, exposed brick, industrial chic design, natural light, collaborative environment",
            negative_extra="empty, dark, boring",
            category="Professional",
            icon="💼",
            guidance_scale=7.0
        ),
        "conference": SceneTemplate(
            key="conference",
            name="Conference Room",
            prompt="modern conference room, large meeting table, glass walls, professional presentation screen, bright corporate lighting, clean minimal design",
            negative_extra="small, cramped, outdated",
            category="Professional",
            icon="🤝",
            guidance_scale=7.5
        ),

        # === Nature Category ===
        "beach_sunset": SceneTemplate(
            key="beach_sunset",
            name="Sunset Beach",
            prompt="beautiful tropical beach at golden hour sunset, palm trees silhouette, calm turquoise ocean waves, warm orange and pink sky, soft sand, paradise vacation vibes",
            negative_extra="storm, rain, crowded, trash",
            category="Nature",
            icon="🏖️",
            guidance_scale=7.0
        ),
        "forest_enchanted": SceneTemplate(
            key="forest_enchanted",
            name="Enchanted Forest",
            prompt="magical enchanted forest, sunlight streaming through tall trees, lush green foliage, mystical atmosphere, morning mist, fairy tale woodland",
            negative_extra="dead trees, dark, scary, barren",
            category="Nature",
            icon="🌲",
            guidance_scale=7.0
        ),
        "mountain_scenic": SceneTemplate(
            key="mountain_scenic",
            name="Mountain Vista",
            prompt="breathtaking mountain landscape, snow-capped peaks, alpine meadow, clear blue sky, majestic scenic view, pristine nature, peaceful atmosphere",
            negative_extra="industrial, polluted, crowded",
            category="Nature",
            icon="🏔️",
            guidance_scale=7.5
        ),
        "garden_spring": SceneTemplate(
            key="garden_spring",
            name="Spring Garden",
            prompt="beautiful spring flower garden, colorful blooming flowers, roses and tulips, manicured hedges, sunny day, botanical paradise, fresh and vibrant",
            negative_extra="dead, winter, wilted, dry",
            category="Nature",
            icon="🌸",
            guidance_scale=7.0
        ),
        "lake_serene": SceneTemplate(
            key="lake_serene",
            name="Serene Lake",
            prompt="peaceful serene lake at dawn, mirror-like water reflection, surrounding mountains, soft morning light, tranquil atmosphere, pristine natural beauty",
            negative_extra="stormy, polluted, industrial",
            category="Nature",
            icon="🏞️",
            guidance_scale=7.0
        ),
        "cherry_blossom": SceneTemplate(
            key="cherry_blossom",
            name="Cherry Blossom",
            prompt="stunning cherry blossom trees in full bloom, pink sakura petals falling gently, Japanese garden aesthetic, soft spring sunlight, romantic atmosphere",
            negative_extra="winter, dead, brown, wilted",
            category="Nature",
            icon="🌸",
            guidance_scale=7.0
        ),

        # === Urban Category ===
        "city_skyline": SceneTemplate(
            key="city_skyline",
            name="City Skyline",
            prompt="modern city skyline at blue hour, impressive skyscrapers, glass buildings reflecting sunset, urban metropolitan view, cinematic atmosphere",
            negative_extra="slums, dirty, abandoned, ruins",
            category="Urban",
            icon="🌆",
            guidance_scale=7.5
        ),
        "cafe_cozy": SceneTemplate(
            key="cafe_cozy",
            name="Cozy Cafe",
            prompt="warm cozy coffee shop interior, wooden furniture, ambient lighting, exposed brick walls, plants, comfortable atmosphere, artisan cafe vibes",
            negative_extra="fast food, plastic, harsh lighting",
            category="Urban",
            icon="☕",
            guidance_scale=7.0
        ),
        "street_european": SceneTemplate(
            key="street_european",
            name="European Street",
            prompt="charming European cobblestone street, historic buildings, outdoor cafe, flowers on balconies, warm afternoon light, romantic Paris or Rome vibes",
            negative_extra="modern, industrial, ugly, dirty",
            category="Urban",
            icon="🏛️",
            guidance_scale=7.0
        ),
        "night_neon": SceneTemplate(
            key="night_neon",
            name="Neon Nightlife",
            prompt="vibrant city nightlife scene, neon lights and signs, urban night atmosphere, colorful reflections on wet street, cyberpunk aesthetic, electric energy",
            negative_extra="daytime, boring, plain",
            category="Urban",
            icon="🌃",
            guidance_scale=6.5
        ),
        "rooftop_view": SceneTemplate(
            key="rooftop_view",
            name="Rooftop Terrace",
            prompt="luxury rooftop terrace, city panoramic view, modern outdoor furniture, string lights, sunset golden hour, sophisticated urban oasis",
            negative_extra="cheap, dirty, crowded",
            category="Urban",
            icon="🏙️",
            guidance_scale=7.5
        ),

        # === Artistic Category ===
        "gradient_soft": SceneTemplate(
            key="gradient_soft",
            name="Soft Gradient",
            prompt="smooth soft gradient background, pastel colors blending beautifully, pink to blue to purple transition, dreamy aesthetic, professional portrait backdrop",
            negative_extra="harsh, noisy, textured, busy",
            category="Artistic",
            icon="🎨",
            guidance_scale=8.0
        ),
        "abstract_modern": SceneTemplate(
            key="abstract_modern",
            name="Modern Abstract",
            prompt="modern abstract art background, geometric shapes, bold colors, contemporary design, artistic composition, museum gallery aesthetic",
            negative_extra="realistic, plain, boring",
            category="Artistic",
            icon="🖼️",
            guidance_scale=6.5
        ),
        "vintage_retro": SceneTemplate(
            key="vintage_retro",
            name="Vintage Retro",
            prompt="vintage retro aesthetic background, warm sepia tones, nostalgic 70s vibes, film grain texture, classic photography style, timeless elegance",
            negative_extra="modern, digital, cold, harsh",
            category="Artistic",
            icon="📻",
            guidance_scale=7.0
        ),
        "watercolor_dream": SceneTemplate(
            key="watercolor_dream",
            name="Watercolor Dream",
            prompt="beautiful watercolor painting background, soft flowing colors, artistic brush strokes, dreamy ethereal atmosphere, delicate artistic aesthetic",
            negative_extra="digital, sharp, photorealistic",
            category="Artistic",
            icon="🖌️",
            guidance_scale=6.5
        ),

        # === Seasonal Category ===
        "autumn_foliage": SceneTemplate(
            key="autumn_foliage",
            name="Autumn Foliage",
            prompt="beautiful autumn scenery, vibrant fall foliage, orange red and golden leaves, maple trees, warm sunlight filtering through, cozy seasonal atmosphere",
            negative_extra="spring, summer, green, snow",
            category="Seasonal",
            icon="🍂",
            guidance_scale=7.0
        ),
        "winter_snow": SceneTemplate(
            key="winter_snow",
            name="Winter Wonderland",
            prompt="magical winter wonderland, fresh white snow covering everything, snow-laden pine trees, soft snowfall, peaceful cold atmosphere, holiday season vibes",
            negative_extra="summer, green, rain, mud",
            category="Seasonal",
            icon="❄️",
            guidance_scale=7.0
        ),
        "summer_tropical": SceneTemplate(
            key="summer_tropical",
            name="Tropical Summer",
            prompt="vibrant tropical summer scene, lush palm trees, bright sunny day, exotic flowers, paradise vacation destination, warm and inviting atmosphere",
            negative_extra="winter, cold, snow, gray",
            category="Seasonal",
            icon="🌴",
            guidance_scale=7.0
        ),
        "spring_meadow": SceneTemplate(
            key="spring_meadow",
            name="Spring Meadow",
            prompt="beautiful spring meadow, wildflowers blooming, fresh green grass, butterflies, soft warm sunlight, renewal and new beginnings, pastoral beauty",
            negative_extra="winter, autumn, dead, dry",
            category="Seasonal",
            icon="🌷",
            guidance_scale=7.0
        ),
    }

    # Category display order
    CATEGORIES = ["Professional", "Nature", "Urban", "Artistic", "Seasonal"]

    def __init__(self):
        """Initialize the scene template manager"""
        logger.info(f"SceneTemplateManager initialized with {len(self.TEMPLATES)} templates")

    def get_all_templates(self) -> Dict[str, SceneTemplate]:
        """Get all available templates"""
        return self.TEMPLATES

    def get_template(self, key: str) -> Optional[SceneTemplate]:
        """Get a specific template by key"""
        return self.TEMPLATES.get(key)

    def get_templates_by_category(self, category: str) -> List[SceneTemplate]:
        """Get all templates in a specific category"""
        return [t for t in self.TEMPLATES.values() if t.category == category]

    def get_categories(self) -> List[str]:
        """Get list of all categories in display order"""
        return self.CATEGORIES

    def get_template_choices_sorted(self) -> List[str]:
        """
        Get template choices formatted for Gradio dropdown.
        Returns list of display strings sorted A-Z: "🏢 Modern Office"
        """
        display_list = []
        for key, template in self.TEMPLATES.items():
            display_name = f"{template.icon} {template.name}"
            display_list.append(display_name)

        # Sort alphabetically by name (ignoring emoji)
        display_list.sort(key=lambda x: x.split(' ', 1)[1] if ' ' in x else x)
        return display_list

    def get_template_key_from_display(self, display_name: str) -> Optional[str]:
        """
        Get template key from display name.
        Example: "🏢 Modern Office" -> "office_modern"
        """
        if not display_name:
            return None

        for key, template in self.TEMPLATES.items():
            if f"{template.icon} {template.name}" == display_name:
                return key
        return None

    def get_prompt_for_template(self, key: str) -> Optional[str]:
        """Get the prompt string for a template"""
        template = self.get_template(key)
        return template.prompt if template else None

    def get_negative_prompt_for_template(
        self,
        key: str,
        base_negative: str = "blurry, low quality, distorted, people, characters"
    ) -> str:
        """Get combined negative prompt for a template"""
        template = self.get_template(key)
        if template and template.negative_extra:
            return f"{base_negative}, {template.negative_extra}"
        return base_negative

    def get_guidance_scale_for_template(self, key: str) -> float:
        """Get the recommended guidance scale for a template"""
        template = self.get_template(key)
        return template.guidance_scale if template else 7.5

    def build_gallery_html(self) -> str:
        """
        Build HTML for the scene template gallery.
        Returns HTML string for display in Gradio.
        """
        html_parts = ['<div class="scene-gallery">']

        for category in self.CATEGORIES:
            templates = self.get_templates_by_category(category)
            if not templates:
                continue

            html_parts.append(f'''
                <div class="scene-category">
                    <h4 class="scene-category-title">{category}</h4>
                    <div class="scene-grid">
            ''')

            for template in templates:
                html_parts.append(f'''
                    <button class="scene-card" data-template="{template.key}" onclick="selectTemplate('{template.key}')">
                        <span class="scene-icon">{template.icon}</span>
                        <span class="scene-name">{template.name}</span>
                    </button>
                ''')

            html_parts.append('</div></div>')

        html_parts.append('</div>')
        return ''.join(html_parts)

    def get_gallery_css(self) -> str:
        """Get CSS styles for the scene gallery"""
        return """
        /* Scene Gallery Styles */
        .scene-gallery {
            margin: 16px 0;
        }

        .scene-category {
            margin-bottom: 20px;
        }

        .scene-category-title {
            font-size: 0.9rem;
            font-weight: 600;
            color: #475569;
            margin-bottom: 12px;
            padding-bottom: 8px;
            border-bottom: 1px solid #e2e8f0;
        }

        .scene-grid {
            display: grid;
            grid-template-columns: repeat(auto-fill, minmax(100px, 1fr));
            gap: 8px;
        }

        .scene-card {
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            padding: 12px 8px;
            background: #f8fafc;
            border: 1px solid #e2e8f0;
            border-radius: 8px;
            cursor: pointer;
            transition: all 0.2s ease;
            min-height: 70px;
        }

        .scene-card:hover {
            background: #dbeafe;
            border-color: #3b82f6;
            transform: translateY(-2px);
            box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        }

        .scene-card.selected {
            background: #dbeafe;
            border-color: #3b82f6;
            box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.3);
        }

        .scene-icon {
            font-size: 1.5rem;
            margin-bottom: 4px;
        }

        .scene-name {
            font-size: 0.75rem;
            font-weight: 500;
            color: #1e293b;
            text-align: center;
            line-height: 1.2;
        }

        @media (max-width: 768px) {
            .scene-grid {
                grid-template-columns: repeat(3, 1fr);
            }
        }
        """
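A usage sketch (not part of scene_templates.py) showing the round trip the UI relies on: dropdown label to template key, then to positive prompt, negative prompt, and guidance scale. The example label is taken from the templates defined above.

# usage_sketch_scene_templates.py -- illustrative only
from scene_templates import SceneTemplateManager

manager = SceneTemplateManager()

label = manager.get_template_choices_sorted()[0]      # "🍂 Autumn Foliage" sorts first by name
key = manager.get_template_key_from_display(label)    # -> "autumn_foliage"

prompt = manager.get_prompt_for_template(key)
negative = manager.get_negative_prompt_for_template(key)
guidance = manager.get_guidance_scale_for_template(key)

print(key, guidance)
print(prompt)
print(negative)   # base negative prompt plus the template's negative_extra terms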
scene_weaver_core.py
ADDED
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import logging
|
| 6 |
+
import gc
|
| 7 |
+
import time
|
| 8 |
+
from typing import Optional, Dict, Any, Tuple, List
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings("ignore")
|
| 12 |
+
|
| 13 |
+
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
| 14 |
+
import open_clip
|
| 15 |
+
from mask_generator import MaskGenerator
|
| 16 |
+
from image_blender import ImageBlender
|
| 17 |
+
from quality_checker import QualityChecker
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
logger.setLevel(logging.INFO)
|
| 21 |
+
|
| 22 |
+
class SceneWeaverCore:
|
| 23 |
+
"""
|
| 24 |
+
SceneWeaver with perfect background generation + fixed blending + memory optimization
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
# Style presets for diversity generation mode
|
| 28 |
+
STYLE_PRESETS = {
|
| 29 |
+
"professional": {
|
| 30 |
+
"name": "Professional Business",
|
| 31 |
+
"modifier": "professional office environment, clean background, corporate setting, bright even lighting",
|
| 32 |
+
"negative_extra": "casual, messy, cluttered",
|
| 33 |
+
"guidance_scale": 8.0
|
| 34 |
+
},
|
| 35 |
+
"casual": {
|
| 36 |
+
"name": "Casual Lifestyle",
|
| 37 |
+
"modifier": "casual outdoor setting, natural environment, relaxed atmosphere, warm natural lighting",
|
| 38 |
+
"negative_extra": "formal, studio",
|
| 39 |
+
"guidance_scale": 7.5
|
| 40 |
+
},
|
| 41 |
+
"artistic": {
|
| 42 |
+
"name": "Artistic Creative",
|
| 43 |
+
"modifier": "artistic background, creative composition, vibrant colors, interesting lighting",
|
| 44 |
+
"negative_extra": "boring, plain",
|
| 45 |
+
"guidance_scale": 6.5
|
| 46 |
+
},
|
| 47 |
+
"nature": {
|
| 48 |
+
"name": "Natural Scenery",
|
| 49 |
+
"modifier": "beautiful natural scenery, outdoor landscape, scenic view, natural lighting",
|
| 50 |
+
"negative_extra": "urban, indoor",
|
| 51 |
+
"guidance_scale": 7.5
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
def __init__(self, device: str = "auto"):
|
| 56 |
+
self.device = self._setup_device(device)
|
| 57 |
+
|
| 58 |
+
# Model configurations - KEEP SAME FOR PERFECT GENERATION
|
| 59 |
+
self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 60 |
+
self.clip_model_name = "ViT-B-32"
|
| 61 |
+
self.clip_pretrained = "openai"
|
| 62 |
+
|
| 63 |
+
# Pipeline objects
|
| 64 |
+
self.pipeline = None
|
| 65 |
+
self.clip_model = None
|
| 66 |
+
self.clip_preprocess = None
|
| 67 |
+
self.clip_tokenizer = None
|
| 68 |
+
self.is_initialized = False
|
| 69 |
+
|
| 70 |
+
# Generation settings - KEEP SAME
|
| 71 |
+
self.max_image_size = 1024
|
| 72 |
+
self.default_steps = 25
|
| 73 |
+
self.use_fp16 = True
|
| 74 |
+
|
| 75 |
+
# Enhanced memory management
|
| 76 |
+
self.generation_count = 0
|
| 77 |
+
self.cleanup_frequency = 1 # More frequent cleanup
|
| 78 |
+
self.max_history = 3 # Limit generation history
|
| 79 |
+
|
| 80 |
+
# Initialize helper classes
|
| 81 |
+
self.mask_generator = MaskGenerator(self.max_image_size)
|
| 82 |
+
self.image_blender = ImageBlender()
|
| 83 |
+
self.quality_checker = QualityChecker()
|
| 84 |
+
|
| 85 |
+
logger.info(f"OptimizedSceneWeaver initialized on {self.device}")
|
| 86 |
+
|
| 87 |
+
def _setup_device(self, device: str) -> str:
|
| 88 |
+
"""Setup computation device"""
|
| 89 |
+
if device == "auto":
|
| 90 |
+
if torch.cuda.is_available():
|
| 91 |
+
return "cuda"
|
| 92 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 93 |
+
return "mps"
|
| 94 |
+
else:
|
| 95 |
+
return "cpu"
|
| 96 |
+
return device
|
| 97 |
+
|
| 98 |
+
def _ultra_memory_cleanup(self):
|
| 99 |
+
"""Ultra aggressive memory cleanup for Colab stability"""
|
| 100 |
+
logger.debug("🧹 Ultra memory cleanup...")
|
| 101 |
+
|
| 102 |
+
# Multiple rounds of garbage collection
|
| 103 |
+
for i in range(5):
|
| 104 |
+
gc.collect()
|
| 105 |
+
|
| 106 |
+
if torch.cuda.is_available():
|
| 107 |
+
# Clear all cached memory
|
| 108 |
+
torch.cuda.empty_cache()
|
| 109 |
+
torch.cuda.ipc_collect()
|
| 110 |
+
|
| 111 |
+
# Force synchronization
|
| 112 |
+
torch.cuda.synchronize()
|
| 113 |
+
|
| 114 |
+
# Clear any remaining memory fragments
|
| 115 |
+
try:
|
| 116 |
+
torch.cuda.memory.empty_cache()
|
| 117 |
+
except:
|
| 118 |
+
pass
|
| 119 |
+
|
| 120 |
+
logger.debug("✅ Ultra cleanup completed")
|
| 121 |
+
|
| 122 |
+
def load_models(self, progress_callback: Optional[callable] = None):
|
| 123 |
+
"""Load AI models - KEEP SAME FOR PERFECT GENERATION"""
|
| 124 |
+
if self.is_initialized:
|
| 125 |
+
logger.info("Models already loaded")
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
logger.info("📥 Loading AI models...")
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
self._ultra_memory_cleanup()
|
| 132 |
+
|
| 133 |
+
if progress_callback:
|
| 134 |
+
progress_callback("Loading OpenCLIP for image understanding...", 20)
|
| 135 |
+
|
| 136 |
+
# Load OpenCLIP - KEEP SAME
|
| 137 |
+
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
|
| 138 |
+
self.clip_model_name,
|
| 139 |
+
pretrained=self.clip_pretrained,
|
| 140 |
+
device=self.device
|
| 141 |
+
)
|
| 142 |
+
self.clip_tokenizer = open_clip.get_tokenizer(self.clip_model_name)
|
| 143 |
+
self.clip_model.eval()
|
| 144 |
+
|
| 145 |
+
logger.info("✅ OpenCLIP loaded")
|
| 146 |
+
|
| 147 |
+
if progress_callback:
|
| 148 |
+
progress_callback("Loading SDXL text-to-image pipeline...", 60)
|
| 149 |
+
|
| 150 |
+
# Load standard SDXL text-to-image pipeline - KEEP SAME
|
| 151 |
+
self.pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 152 |
+
self.base_model_id,
|
| 153 |
+
torch_dtype=torch.float16 if self.use_fp16 else torch.float32,
|
| 154 |
+
use_safetensors=True,
|
| 155 |
+
variant="fp16" if self.use_fp16 else None
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Use DPM solver for faster generation - KEEP SAME
|
| 159 |
+
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 160 |
+
self.pipeline.scheduler.config
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Move to device
|
| 164 |
+
self.pipeline = self.pipeline.to(self.device)
|
| 165 |
+
|
| 166 |
+
if progress_callback:
|
| 167 |
+
progress_callback("Applying optimizations...", 90)
|
| 168 |
+
|
| 169 |
+
# Memory optimizations - ENHANCED
|
| 170 |
+
try:
|
| 171 |
+
self.pipeline.enable_xformers_memory_efficient_attention()
|
| 172 |
+
logger.info("✅ xformers enabled")
|
| 173 |
+
except Exception:
|
| 174 |
+
try:
|
| 175 |
+
self.pipeline.enable_attention_slicing()
|
| 176 |
+
logger.info("✅ Attention slicing enabled")
|
| 177 |
+
except Exception:
|
| 178 |
+
logger.warning("⚠️ No memory optimizations available")
|
| 179 |
+
|
| 180 |
+
# Additional memory optimizations
|
| 181 |
+
if hasattr(self.pipeline, 'enable_vae_tiling'):
|
| 182 |
+
self.pipeline.enable_vae_tiling()
|
| 183 |
+
|
| 184 |
+
if hasattr(self.pipeline, 'enable_vae_slicing'):
|
| 185 |
+
self.pipeline.enable_vae_slicing()
|
| 186 |
+
|
| 187 |
+
# Set to eval mode
|
| 188 |
+
self.pipeline.unet.eval()
|
| 189 |
+
if hasattr(self.pipeline, 'vae'):
|
| 190 |
+
self.pipeline.vae.eval()
|
| 191 |
+
|
| 192 |
+
# Enable sequential CPU offload if very low on memory
|
| 193 |
+
try:
|
| 194 |
+
if torch.cuda.is_available():
|
| 195 |
+
free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
|
| 196 |
+
if free_memory < 4 * 1024**3: # Less than 4GB free
|
| 197 |
+
self.pipeline.enable_sequential_cpu_offload()
|
| 198 |
+
logger.info("✅ Sequential CPU offload enabled for low memory")
|
| 199 |
+
except:
|
| 200 |
+
pass
|
| 201 |
+
|
| 202 |
+
self.is_initialized = True
|
| 203 |
+
|
| 204 |
+
if progress_callback:
|
| 205 |
+
progress_callback("Models loaded successfully!", 100)
|
| 206 |
+
|
| 207 |
+
# Memory status
|
| 208 |
+
if torch.cuda.is_available():
|
| 209 |
+
memory_used = torch.cuda.memory_allocated() / 1024**3
|
| 210 |
+
memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
| 211 |
+
logger.info(f"📊 GPU Memory: {memory_used:.1f}GB / {memory_total:.1f}GB")
|
| 212 |
+
|
| 213 |
+
except Exception as e:
|
| 214 |
+
logger.error(f"❌ Model loading failed: {e}")
|
| 215 |
+
raise RuntimeError(f"Failed to load models: {str(e)}")
|
| 216 |
+
|
| 217 |
+
    def analyze_image_with_clip(self, image: Image.Image) -> str:
        """Analyze uploaded image using OpenCLIP - KEEP SAME"""
        if not self.clip_model:
            return "Image analysis not available"

        try:
            image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)

            categories = [
                "a photo of a person",
                "a photo of an animal",
                "a photo of an object",
                "a photo of a character",
                "a photo of a cartoon",
                "a photo of nature",
                "a photo of a building",
                "a photo of a landscape"
            ]

            text_inputs = self.clip_tokenizer(categories).to(self.device)

            with torch.no_grad():
                image_features = self.clip_model.encode_image(image_input)
                text_features = self.clip_model.encode_text(text_inputs)

                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

            best_match_idx = similarity.argmax().item()
            confidence = similarity[0, best_match_idx].item()

            category = categories[best_match_idx].replace("a photo of ", "")

            return f"Detected: {category} (confidence: {confidence:.1%})"

        except Exception as e:
            logger.error(f"CLIP analysis failed: {e}")
            return "Image analysis failed"

    def enhance_prompt(
        self,
        user_prompt: str,
        foreground_image: Image.Image
    ) -> str:
        """
        Smart prompt enhancement based on image analysis.
        Adds appropriate lighting, atmosphere, and quality descriptors.

        Args:
            user_prompt: Original user-provided prompt
            foreground_image: Foreground image for analysis

        Returns:
            Enhanced prompt string
        """
        logger.info("✨ Enhancing prompt based on image analysis...")

        try:
            # Analyze image characteristics
            img_array = np.array(foreground_image.convert('RGB'))

            # === Analyze color temperature ===
            # Convert to LAB to analyze color temperature
            lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
            avg_a = np.mean(lab[:, :, 1])  # a channel: green(-) to red(+)
            avg_b = np.mean(lab[:, :, 2])  # b channel: blue(-) to yellow(+)

            # Determine warm/cool tone
            is_warm = avg_b > 128  # b > 128 means more yellow/warm

            # === Analyze brightness ===
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            avg_brightness = np.mean(gray)
            is_bright = avg_brightness > 127

            # === Get subject type from CLIP ===
            clip_analysis = self.analyze_image_with_clip(foreground_image)
            subject_type = "unknown"

            if "person" in clip_analysis.lower():
                subject_type = "person"
            elif "animal" in clip_analysis.lower():
                subject_type = "animal"
            elif "object" in clip_analysis.lower():
                subject_type = "object"
            elif "character" in clip_analysis.lower() or "cartoon" in clip_analysis.lower():
                subject_type = "character"
            elif "nature" in clip_analysis.lower() or "landscape" in clip_analysis.lower():
                subject_type = "nature"

            # === Build prompt fragments library ===
            lighting_options = {
                "warm_bright": "warm golden hour lighting, soft natural light",
                "warm_dark": "warm ambient lighting, cozy atmosphere",
                "cool_bright": "bright daylight, clear sky lighting",
                "cool_dark": "soft diffused light, gentle shadows"
            }

            atmosphere_options = {
                "person": "professional, elegant composition",
                "animal": "natural, harmonious setting",
                "object": "clean product photography style",
                "character": "artistic, vibrant, imaginative",
                "nature": "scenic, peaceful atmosphere",
                "unknown": "balanced composition"
            }

            quality_modifiers = "high quality, detailed, sharp focus, photorealistic"

            # === Select appropriate fragments ===
            # Lighting based on color temperature and brightness
            if is_warm and is_bright:
                lighting = lighting_options["warm_bright"]
            elif is_warm and not is_bright:
                lighting = lighting_options["warm_dark"]
            elif not is_warm and is_bright:
                lighting = lighting_options["cool_bright"]
            else:
                lighting = lighting_options["cool_dark"]

            # Atmosphere based on subject type
            atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])

            # === Check for conflicts in user prompt ===
            user_prompt_lower = user_prompt.lower()

            # Avoid adding conflicting descriptions
            if "sunset" in user_prompt_lower or "golden" in user_prompt_lower:
                lighting = ""  # User already specified lighting
            if "dark" in user_prompt_lower or "night" in user_prompt_lower:
                lighting = lighting.replace("bright", "").replace("daylight", "")

            # === Combine enhanced prompt ===
            fragments = [user_prompt]

            if lighting:
                fragments.append(lighting)
            if atmosphere:
                fragments.append(atmosphere)
            fragments.append(quality_modifiers)

            enhanced_prompt = ", ".join(filter(None, fragments))

            logger.info(f"📝 Original prompt: {user_prompt[:50]}...")
            logger.info(f"📝 Enhanced prompt: {enhanced_prompt[:80]}...")

            return enhanced_prompt

        except Exception as e:
            logger.warning(f"⚠️ Prompt enhancement failed: {e}, using original prompt")
            return user_prompt

    def _prepare_image(self, image: Image.Image) -> Image.Image:
        """Prepare image for processing - KEEP SAME"""
        # Convert to RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Resize if too large
        width, height = image.size
        max_size = self.max_image_size

        if width > max_size or height > max_size:
            ratio = min(max_size/width, max_size/height)
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.LANCZOS)

        # Ensure dimensions are multiple of 8
        width, height = image.size
        new_width = (width // 8) * 8
        new_height = (height // 8) * 8

        if new_width != width or new_height != height:
            image = image.resize((new_width, new_height), Image.LANCZOS)

        return image

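    # generate_background runs a plain SDXL text-to-image pass at the requested
    # size; callers are expected to pass dimensions already rounded to multiples
    # of 8 (see _prepare_image above). A fixed seed (42) keeps results repeatable.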
    def generate_background(
        self,
        prompt: str,
        width: int,
        height: int,
        negative_prompt: str = "blurry, low quality, distorted",
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        progress_callback: Optional[callable] = None
    ) -> Image.Image:
        """Generate complete background using standard text-to-image - KEEP SAME"""

        if not self.is_initialized:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        logger.info(f"🎨 Generating background: {prompt[:50]}...")

        try:
            with torch.inference_mode():
                if progress_callback:
                    progress_callback("Generating background with SDXL...", 50)

                # Standard text-to-image generation - KEEP SAME
                result = self.pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=torch.Generator(device=self.device).manual_seed(42)
                )

                generated_image = result.images[0]

                if progress_callback:
                    progress_callback("Background generated successfully!", 100)

                logger.info("✅ Background generation completed!")
                return generated_image

        except torch.cuda.OutOfMemoryError:
            logger.error("❌ GPU memory exhausted")
            self._ultra_memory_cleanup()
            raise RuntimeError("GPU memory insufficient")

        except Exception as e:
            logger.error(f"❌ Background generation failed: {e}")
            raise RuntimeError(f"Generation failed: {str(e)}")

    def generate_and_combine(
        self,
        original_image: Image.Image,
        prompt: str,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        progress_callback: Optional[callable] = None,
        enable_prompt_enhancement: bool = True
    ) -> Dict[str, Any]:
        """
        Generate background and combine with foreground using advanced blending.

        Args:
            original_image: Foreground image
            prompt: User's background description
            combination_mode: How to position foreground ("center", "left_half", "right_half", "full")
            focus_mode: Focus type ("person" for tight crop, "scene" for wider context)
            negative_prompt: What to avoid in generation
            num_inference_steps: SDXL inference steps
            guidance_scale: Classifier-free guidance scale
            progress_callback: Progress reporting callback
            enable_prompt_enhancement: Whether to use smart prompt enhancement

        Returns:
            Dictionary containing results and metadata
        """

        if not self.is_initialized:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        logger.info(f"🎨 Starting generation and combination with advanced features...")

        try:
            # Enhanced memory management
            if self.generation_count % self.cleanup_frequency == 0:
                self._ultra_memory_cleanup()

            if progress_callback:
                progress_callback("Analyzing uploaded image...", 5)

            # Analyze original image
            image_analysis = self.analyze_image_with_clip(original_image)

            if progress_callback:
                progress_callback("Preparing images...", 10)

            # Prepare original image
            processed_original = self._prepare_image(original_image)
            target_width, target_height = processed_original.size

            if progress_callback:
                progress_callback("Optimizing prompt...", 15)

            # Smart prompt enhancement
            if enable_prompt_enhancement:
                enhanced_prompt = self.enhance_prompt(prompt, processed_original)
            else:
                enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic, beautiful scenery"

            enhanced_negative = f"{negative_prompt}, people, characters, cartoons, logos"

            if progress_callback:
                progress_callback("Generating complete background scene...", 25)

            def bg_progress(msg, pct):
                if progress_callback:
                    progress_callback(f"Background: {msg}", 25 + (pct/100) * 50)

            generated_background = self.generate_background(
                prompt=enhanced_prompt,
                width=target_width,
                height=target_height,
                negative_prompt=enhanced_negative,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                progress_callback=bg_progress
            )

            if progress_callback:
                progress_callback("Creating intelligent mask for person detection...", 80)

            # Use intelligent mask generation with enhanced logging
            logger.info("🎭 Starting intelligent mask generation...")
            combination_mask = self.mask_generator.create_gradient_based_mask(
                processed_original,
                combination_mode,
                focus_mode
            )

            # Log mask quality for debugging
            try:
                mask_array = np.array(combination_mask)
                logger.info(f"📊 Generated mask stats - Mean: {mask_array.mean():.1f}, Non-zero pixels: {np.count_nonzero(mask_array)}")
            except Exception as mask_debug_error:
                logger.warning(f"⚠️ Mask debug logging failed: {mask_debug_error}")

            if progress_callback:
                progress_callback("Advanced image blending...", 90)

            # Use advanced image blending with logging
            logger.info("🖌️ Starting advanced image blending...")
            combined_image = self.image_blender.simple_blend_images(
                processed_original,
                generated_background,
                combination_mask
            )
            logger.info("✅ Image blending completed successfully")

            if progress_callback:
                progress_callback("Creating debug images...", 95)

            # Generate debug images
            debug_images = self.image_blender.create_debug_images(
                processed_original,
                generated_background,
                combination_mask,
                combined_image
            )

            # Memory cleanup after generation
            self._ultra_memory_cleanup()

            # Update generation count
            self.generation_count += 1

            if progress_callback:
                progress_callback("Generation complete!", 100)

            logger.info("✅ Complete generation and combination with fixed blending successful!")

            return {
                "combined_image": combined_image,
                "generated_scene": generated_background,
                "original_image": processed_original,
                "combination_mask": combination_mask,
                "debug_mask_gray": debug_images["mask_gray"],
                "debug_alpha_heatmap": debug_images["alpha_heatmap"],
                "image_analysis": image_analysis,
                "enhanced_prompt": enhanced_prompt,
                "original_prompt": prompt,
                "success": True,
                "generation_count": self.generation_count
            }

        except Exception as e:
            import traceback
            error_traceback = traceback.format_exc()
            logger.error(f"❌ Generation and combination failed: {str(e)}")
            logger.error(f"📍 Full traceback:\n{error_traceback}")
            print(f"❌ DETAILED ERROR in scene_weaver_core.generate_and_combine:")
            print(f"Error: {str(e)}")
            print(f"Traceback:\n{error_traceback}")
            self._ultra_memory_cleanup()  # Cleanup on error too
            return {
                "success": False,
                "error": f"Failed: {str(e)}"
            }

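    # generate_diversity_variants trades quality for speed: previews are capped
    # at 768px on the long side, rounded down to multiples of 8, and rendered
    # with only 15 inference steps so several style presets can be compared fast.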
    def generate_diversity_variants(
        self,
        original_image: Image.Image,
        prompt: str,
        selected_styles: Optional[List[str]] = None,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        progress_callback: Optional[callable] = None
    ) -> Dict[str, Any]:
        """
        Generate multiple style variants of the background.
        Uses reduced quality for faster preview generation.

        Args:
            original_image: Foreground image
            prompt: Base background description
            selected_styles: List of style keys to use (None = all styles)
            combination_mode: Foreground positioning mode
            focus_mode: Focus type for mask generation
            negative_prompt: Base negative prompt
            progress_callback: Progress callback function

        Returns:
            Dictionary containing variants and metadata
        """
        if not self.is_initialized:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        logger.info("🎨 Starting diversity generation mode...")

        # Determine which styles to generate
        styles_to_generate = selected_styles or list(self.STYLE_PRESETS.keys())
        num_styles = len(styles_to_generate)

        results = {
            "variants": [],
            "success": True,
            "num_variants": 0
        }

        try:
            # Pre-process image once
            processed_original = self._prepare_image(original_image)
            target_width, target_height = processed_original.size

            # Reduce resolution for faster generation
            preview_size = min(768, max(target_width, target_height))
            scale = preview_size / max(target_width, target_height)
            preview_width = int(target_width * scale) // 8 * 8
            preview_height = int(target_height * scale) // 8 * 8

            # Generate mask once (reusable for all variants)
            if progress_callback:
                progress_callback("Creating foreground mask...", 5)

            combination_mask = self.mask_generator.create_gradient_based_mask(
                processed_original, combination_mode, focus_mode
            )

            # Resize mask for preview resolution
            preview_mask = combination_mask.resize((preview_width, preview_height), Image.LANCZOS)
            preview_original = processed_original.resize((preview_width, preview_height), Image.LANCZOS)

            # Generate each style variant
            for idx, style_key in enumerate(styles_to_generate):
                if style_key not in self.STYLE_PRESETS:
                    logger.warning(f"⚠️ Unknown style: {style_key}, skipping")
                    continue

                style = self.STYLE_PRESETS[style_key]
                style_name = style["name"]

                if progress_callback:
                    base_pct = 10 + (idx / num_styles) * 80
                    progress_callback(f"Generating {style_name} variant...", int(base_pct))

                logger.info(f"🎨 Generating variant: {style_name}")

                try:
                    # Build style-specific prompt
                    styled_prompt = f"{prompt}, {style['modifier']}, high quality, detailed"
                    styled_negative = f"{negative_prompt}, {style['negative_extra']}, people, characters"

                    # Generate background with reduced steps for speed
                    background = self.generate_background(
                        prompt=styled_prompt,
                        width=preview_width,
                        height=preview_height,
                        negative_prompt=styled_negative,
                        num_inference_steps=15,  # Reduced for speed
                        guidance_scale=style["guidance_scale"]
                    )

                    # Blend images
                    combined = self.image_blender.simple_blend_images(
                        preview_original,
                        background,
                        preview_mask,
                        use_multi_scale=False  # Skip for speed
                    )

                    results["variants"].append({
                        "style_key": style_key,
                        "style_name": style_name,
                        "combined_image": combined,
                        "background": background,
                        "prompt_used": styled_prompt
                    })

                    # Memory cleanup between variants
                    self._ultra_memory_cleanup()

                except Exception as variant_error:
                    logger.error(f"❌ Failed to generate {style_name} variant: {variant_error}")
                    continue

            results["num_variants"] = len(results["variants"])

            if progress_callback:
                progress_callback("Diversity generation complete!", 100)

            logger.info(f"✅ Generated {results['num_variants']} style variants")
            return results

        except Exception as e:
            logger.error(f"❌ Diversity generation failed: {e}")
            self._ultra_memory_cleanup()
            return {
                "variants": [],
                "success": False,
                "error": str(e),
                "num_variants": 0
            }

    def regenerate_high_quality(
        self,
        original_image: Image.Image,
        prompt: str,
        style_key: str,
        combination_mode: str = "center",
        focus_mode: str = "person",
        negative_prompt: str = "blurry, low quality, distorted",
        progress_callback: Optional[callable] = None
    ) -> Dict[str, Any]:
        """
        Regenerate a specific style at full quality.

        Args:
            original_image: Original foreground image
            prompt: Base prompt
            style_key: Style preset key to use
            combination_mode: Foreground positioning
            focus_mode: Mask focus mode
            negative_prompt: Base negative prompt
            progress_callback: Progress callback

        Returns:
            Full quality result dictionary
        """
        if style_key not in self.STYLE_PRESETS:
            return {"success": False, "error": f"Unknown style: {style_key}"}

        style = self.STYLE_PRESETS[style_key]

        # Build styled prompt
        styled_prompt = f"{prompt}, {style['modifier']}"
        styled_negative = f"{negative_prompt}, {style['negative_extra']}"

        # Use full generate_and_combine with style parameters
        return self.generate_and_combine(
            original_image=original_image,
            prompt=styled_prompt,
            combination_mode=combination_mode,
            focus_mode=focus_mode,
            negative_prompt=styled_negative,
            num_inference_steps=25,  # Full quality
            guidance_scale=style["guidance_scale"],
            progress_callback=progress_callback,
            enable_prompt_enhancement=True
        )

    def get_memory_status(self) -> Dict[str, Any]:
        """Enhanced memory status reporting"""
        status = {"device": self.device}

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3

            status.update({
                "gpu_allocated_gb": round(allocated, 2),
                "gpu_total_gb": round(total, 2),
                "gpu_cached_gb": round(cached, 2),
                "gpu_free_gb": round(total - allocated, 2),
                "gpu_usage_percent": round((allocated / total) * 100, 1),
                "generation_count": self.generation_count
            })

        return status
ui_manager.py
ADDED
@@ -0,0 +1,513 @@
import logging
import time
from pathlib import Path
from typing import Optional, Tuple
from PIL import Image
import numpy as np
import cv2
import gradio as gr
import spaces

from scene_weaver_core import SceneWeaverCore
from css_styles import CSSStyles
from scene_templates import SceneTemplateManager

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(name)s] %(levelname)s: %(message)s',
    datefmt='%H:%M:%S'
)


class UIManager:
    """Gradio UI with enhanced memory management and professional design"""

    def __init__(self):
        self.sceneweaver = SceneWeaverCore()
        self.template_manager = SceneTemplateManager()
        self.generation_history = []
        self._preview_sensitivity = 0.5

    def apply_template(self, display_name: str, current_negative: str) -> Tuple[str, str, float]:
        """
        Apply a scene template to the prompt fields.

        Args:
            display_name: The display name from dropdown (e.g., "🏢 Modern Office")
            current_negative: Current negative prompt value

        Returns:
            Tuple of (prompt, negative_prompt, guidance_scale)
        """
        if not display_name:
            return "", current_negative, 7.5

        # Convert display name to template key
        template_key = self.template_manager.get_template_key_from_display(display_name)
        if not template_key:
            return "", current_negative, 7.5

        template = self.template_manager.get_template(template_key)
        if template:
            prompt = template.prompt
            negative = self.template_manager.get_negative_prompt_for_template(
                template_key, current_negative
            )
            guidance = template.guidance_scale
            return prompt, negative, guidance

        return "", current_negative, 7.5

    def quick_preview(
        self,
        uploaded_image: Optional[Image.Image],
        sensitivity: float = 0.5
    ) -> Optional[Image.Image]:
        """
        Generate quick foreground preview using lightweight traditional methods.

        Args:
            uploaded_image: Uploaded PIL Image
            sensitivity: Detection sensitivity (0.0 - 1.0)

        Returns:
            Preview image with colored overlay or None
        """
        if uploaded_image is None:
            return None

        try:
            logger.info(f"Generating quick preview (sensitivity={sensitivity:.2f})")

            img_array = np.array(uploaded_image.convert('RGB'))
            height, width = img_array.shape[:2]

            max_preview_size = 512
            if max(width, height) > max_preview_size:
                scale = max_preview_size / max(width, height)
                new_w = int(width * scale)
                new_h = int(height * scale)
                img_array = cv2.resize(img_array, (new_w, new_h), interpolation=cv2.INTER_AREA)
                height, width = new_h, new_w

            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            blurred = cv2.GaussianBlur(gray, (5, 5), 0)

            low_threshold = int(30 + (1 - sensitivity) * 50)
            high_threshold = int(100 + (1 - sensitivity) * 100)
            edges = cv2.Canny(blurred, low_threshold, high_threshold)

            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            dilated = cv2.dilate(edges, kernel, iterations=2)

            contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            mask = np.zeros((height, width), dtype=np.uint8)

            if contours:
                sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)
                min_area = (width * height) * 0.01 * (1 - sensitivity)
                for contour in sorted_contours:
                    if cv2.contourArea(contour) > min_area:
                        cv2.fillPoly(mask, [contour], 255)

            kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_close)

            overlay = img_array.copy().astype(np.float32)

            fg_mask = mask > 127
            overlay[fg_mask] = overlay[fg_mask] * 0.5 + np.array([0, 255, 0]) * 0.5

            bg_mask = mask <= 127
            overlay[bg_mask] = overlay[bg_mask] * 0.5 + np.array([255, 0, 0]) * 0.5

            overlay = np.clip(overlay, 0, 255).astype(np.uint8)

            original_size = uploaded_image.size
            preview_image = Image.fromarray(overlay)
            if preview_image.size != original_size:
                preview_image = preview_image.resize(original_size, Image.LANCZOS)

            logger.info("Quick preview generated successfully")
            return preview_image

        except Exception as e:
            logger.error(f"Quick preview failed: {e}")
            return None

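    # quick_preview above is a pure OpenCV heuristic (Canny edges, dilation,
    # contour filling): higher sensitivity lowers the Canny thresholds and the
    # minimum contour area, so more of the image is treated as foreground.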
    def _save_result(self, combined_image: Image.Image, prompt: str):
        """Save result with memory-conscious history management"""
        if not combined_image:
            return

        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        combined_image.save(output_dir / "latest_combined.png")

        self.generation_history.append({
            "prompt": prompt,
            "timestamp": time.time()
        })

        max_history = self.sceneweaver.max_history
        if len(self.generation_history) > max_history:
            self.generation_history = self.generation_history[-max_history:]

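    # The @spaces.GPU decorator below requests a ZeroGPU device for the duration
    # of each call (120 s budget here), so model loading and generation must both
    # complete inside the decorated handler.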
    @spaces.GPU(duration=120)
    def generate_handler(
        self,
        uploaded_image: Optional[Image.Image],
        prompt: str,
        combination_mode: str,
        focus_mode: str,
        negative_prompt: str,
        steps: int,
        guidance: float,
        progress=gr.Progress()
    ):
        """Enhanced generation handler with memory management and ZeroGPU support"""

        if uploaded_image is None:
            return None, None, None, "Please upload an image to get started!", gr.update(visible=False)

        if not prompt.strip():
            return None, None, None, "Please describe the background scene you'd like!", gr.update(visible=False)

        try:
            if not self.sceneweaver.is_initialized:
                progress(0.05, desc="Loading AI models (first time may take 2-3 minutes)...")

                def init_progress(msg, pct):
                    if pct < 30:
                        desc = "Loading image analysis models..."
                    elif pct < 60:
                        desc = "Loading Stable Diffusion XL..."
                    elif pct < 90:
                        desc = "Applying memory optimizations..."
                    else:
                        desc = "Almost ready..."
                    progress(0.05 + (pct/100) * 0.2, desc=desc)

                self.sceneweaver.load_models(progress_callback=init_progress)

            def gen_progress(msg, pct):
                if pct < 20:
                    desc = "Analyzing your image..."
                elif pct < 50:
                    desc = "Generating background scene..."
                elif pct < 80:
                    desc = "Blending foreground and background..."
                elif pct < 95:
                    desc = "Applying final touches..."
                else:
                    desc = "Complete!"
                progress(0.25 + (pct/100) * 0.75, desc=desc)

            result = self.sceneweaver.generate_and_combine(
                original_image=uploaded_image,
                prompt=prompt,
                combination_mode=combination_mode,
                focus_mode=focus_mode,
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance),
                progress_callback=gen_progress
            )

            if result["success"]:
                combined = result["combined_image"]
                generated = result["generated_scene"]
                original = result["original_image"]

                self._save_result(combined, prompt)

                status_msg = "Image created successfully!"

                return combined, generated, original, status_msg, gr.update(visible=True)
            else:
                error_msg = result.get("error", "Something went wrong")
                return None, None, None, f"Error: {error_msg}", gr.update(visible=False)

        except Exception as e:
            import traceback
            error_traceback = traceback.format_exc()
            logger.error(f"Generation handler error: {str(e)}")
            logger.error(f"Traceback:\n{error_traceback}")
            return None, None, None, f"Error: {str(e)}", gr.update(visible=False)

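    # create_interface lays out a two-column gr.Blocks app: the left column holds
    # the upload widget, template picker, prompt boxes, and sliders, while the
    # right column shows the result tabs, status box, and download/clear buttons.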
    def create_interface(self):
        """Create professional user interface"""

        css = CSSStyles.get_main_css()

        with gr.Blocks(
            css=css,
            title="SceneWeaver - AI Background Generator",
            theme=gr.themes.Soft()
        ) as interface:

            # Header
            gr.HTML("""
                <div class="main-header">
                    <h1 class="main-title">
                        <span class="title-emoji">🎨</span>
                        SceneWeaver
                    </h1>
                    <p class="main-subtitle">AI-powered background generation with professional edge processing</p>
                </div>
            """)

            with gr.Row():
                # Left Column - Input controls
                with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
                    gr.HTML("""
                        <div class="card-content">
                            <h3 class="card-title">
                                <span class="section-emoji">📸</span>
                                Upload & Generate
                            </h3>
                        </div>
                    """)

                    uploaded_image = gr.Image(
                        label="Upload Your Image",
                        type="pil",
                        height=280,
                        elem_classes=["input-field"]
                    )

                    # Scene Template Selector
                    with gr.Accordion("Scene Templates", open=False):
                        template_dropdown = gr.Dropdown(
                            label="Select a Scene",
                            choices=[""] + self.template_manager.get_template_choices_sorted(),
                            value="",
                            info="24 curated scenes sorted A-Z",
                            elem_classes=["template-dropdown"]
                        )

                    prompt_input = gr.Textbox(
                        label="Background Scene Description",
                        placeholder="Select a template above or describe your own scene...",
                        lines=3,
                        elem_classes=["input-field"]
                    )

                    combination_mode = gr.Dropdown(
                        label="Composition Mode",
                        choices=["center", "left_half", "right_half", "full"],
                        value="center",
                        info="center=Smart Center | left_half=Left Half | right_half=Right Half | full=Full Image",
                        elem_classes=["input-field"]
                    )

                    focus_mode = gr.Dropdown(
                        label="Focus Mode",
                        choices=["person", "scene"],
                        value="person",
                        info="person=Tight Crop | scene=Include Surrounding Objects",
                        elem_classes=["input-field"]
                    )

                    with gr.Accordion("Advanced Options", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="blurry, low quality, distorted, people, characters",
                            lines=2,
                            elem_classes=["input-field"]
                        )

                        steps_slider = gr.Slider(
                            label="Quality Steps",
                            minimum=15,
                            maximum=50,
                            value=25,
                            step=5,
                            elem_classes=["input-field"]
                        )

                        guidance_slider = gr.Slider(
                            label="Guidance Scale",
                            minimum=5.0,
                            maximum=15.0,
                            value=7.5,
                            step=0.5,
                            elem_classes=["input-field"]
                        )

                    generate_btn = gr.Button(
                        "Generate Background",
                        variant="primary",
                        size="lg",
                        elem_classes=["primary-button"]
                    )

                # Right Column - Results display
                with gr.Column(scale=2, elem_classes=["feature-card"], elem_id="results-gallery-centered"):
                    gr.HTML("""
                        <div class="card-content">
                            <h3 class="card-title">
                                <span class="section-emoji">🎭</span>
                                Results Gallery
                            </h3>
                        </div>
                    """)

                    # Loading notice
                    gr.HTML("""
                        <div class="loading-notice">
                            <span class="loading-notice-icon">⏱️</span>
                            <span class="loading-notice-text">
                                <strong>First-time users:</strong> Initial model loading takes 1-2 minutes.
                                Subsequent generations are much faster (~30s).
                            </span>
                        </div>
                    """)

                    # Quick start guide
                    gr.HTML("""
                        <details class="user-guidance-panel">
                            <summary class="guidance-summary">
                                <span class="emoji-enhanced">💡</span>
                                Quick Start Guide
                            </summary>
                            <div class="guidance-content">
                                <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
                                <p><strong>Step 2:</strong> Describe or choose your desired background scene</p>
                                <p><strong>Step 3:</strong> Choose composition mode (center works best)</p>
                                <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
                                <p><strong>Tip:</strong> For dark clothing, ensure good lighting in the original photo.</p>
                            </div>
                        </details>
                    """)

                    with gr.Tabs():
                        with gr.TabItem("Final Result"):
                            combined_output = gr.Image(
                                label="Your Generated Image",
                                elem_classes=["result-gallery"],
                                show_label=False
                            )
                        with gr.TabItem("Background"):
                            generated_output = gr.Image(
                                label="Generated Background",
                                elem_classes=["result-gallery"],
                                show_label=False
                            )
                        with gr.TabItem("Original"):
                            original_output = gr.Image(
                                label="Processed Original",
                                elem_classes=["result-gallery"],
                                show_label=False
                            )

                    status_output = gr.Textbox(
                        label="Status",
                        value="Ready to create! Upload an image and describe your vision.",
                        interactive=False,
                        elem_classes=["status-panel", "status-ready"]
                    )

                    with gr.Row():
                        download_btn = gr.DownloadButton(
                            "Download Result",
                            value=None,
                            visible=False,
                            elem_classes=["secondary-button"]
                        )
                        clear_btn = gr.Button(
                            "Clear All",
                            elem_classes=["secondary-button"]
                        )
                        memory_btn = gr.Button(
                            "Clean Memory",
                            elem_classes=["secondary-button"]
                        )

            # Footer with tech credits
            gr.HTML("""
                <div class="app-footer">
                    <div class="footer-powered">
                        <p class="footer-powered-title">Powered By</p>
                        <div class="footer-tech-grid">
                            <span class="footer-tech-item">Stable Diffusion XL</span>
                            <span class="footer-tech-item">OpenCLIP</span>
                            <span class="footer-tech-item">BiRefNet</span>
                            <span class="footer-tech-item">rembg</span>
                            <span class="footer-tech-item">PyTorch</span>
                            <span class="footer-tech-item">Gradio</span>
                        </div>
                    </div>
                    <div class="footer-divider"></div>
                    <p class="footer-copyright">
                        SceneWeaver © 2025 |
                        Built with <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">SDXL</a>
                        and <a href="https://github.com/mlfoundations/open_clip" target="_blank">OpenCLIP</a>
                    </p>
                </div>
            """)

            # Event handlers
            # Template selection handler
            template_dropdown.change(
                fn=self.apply_template,
                inputs=[template_dropdown, negative_prompt],
                outputs=[prompt_input, negative_prompt, guidance_slider]
            )

            generate_btn.click(
                fn=self.generate_handler,
                inputs=[
                    uploaded_image,
                    prompt_input,
                    combination_mode,
                    focus_mode,
                    negative_prompt,
                    steps_slider,
                    guidance_slider
                ],
                outputs=[
                    combined_output,
                    generated_output,
                    original_output,
                    status_output,
                    download_btn
                ]
            )

            clear_btn.click(
                fn=lambda: (None, None, None, "Ready to create!", gr.update(visible=False)),
                outputs=[combined_output, generated_output, original_output, status_output, download_btn]
            )

            memory_btn.click(
                fn=lambda: self.sceneweaver._ultra_memory_cleanup() or "Memory cleaned!",
                outputs=[status_output]
            )

            combined_output.change(
                fn=lambda img: gr.update(value="outputs/latest_combined.png", visible=True) if (img is not None) else gr.update(visible=False),
                inputs=[combined_output],
                outputs=[download_btn]
            )

        return interface

    def launch(self, share: bool = True, debug: bool = False):
        """Launch the UI interface"""
        interface = self.create_interface()

        return interface.launch(
            share=share,
            debug=debug,
            show_error=True,
            height=800,
            favicon_path=None,
            ssl_verify=False,
            quiet=False
        )