DawnC committed on
Commit ca80d1d · verified
1 Parent(s): f3522f4

Upload 10 files

Files changed (10)
  1. app.py +82 -0
  2. css_styles.py +513 -0
  3. image_blender.py +802 -0
  4. mask_generator.py +650 -0
  5. model_manager.py +293 -0
  6. quality_checker.py +409 -0
  7. requirements.txt +81 -0
  8. scene_templates.py +429 -0
  9. scene_weaver_core.py +808 -0
  10. ui_manager.py +513 -0
app.py ADDED
@@ -0,0 +1,82 @@
1
+ import sys
2
+ import warnings
3
+ warnings.filterwarnings("ignore")
4
+
5
+ from ui_manager import UIManager
6
+
7
+ def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
8
+ """Launch SceneWeaver Application"""
9
+
10
+ print("🎨 Starting SceneWeaver...")
11
+ print("✨ AI-Powered Image Background Generation")
12
+
13
+ try:
14
+ # Test imports first
15
+ print("🔍 Testing imports...")
16
+ try:
17
+ # Test creating UIManager
18
+ print("🔍 Creating UIManager instance...")
19
+ ui = UIManager()
20
+ print("✅ UIManager instance created successfully")
21
+
22
+ # Launch UI
23
+ print("🚀 Launching interface...")
24
+ interface = ui.launch(share=share, debug=debug)
25
+ print("✅ Interface launched successfully")
26
+ return interface
27
+
28
+ except ImportError as import_error:
29
+ import traceback
30
+ print(f"❌ Import failed: {import_error}")
31
+ print(f"Traceback: {traceback.format_exc()}")
32
+ raise
33
+
34
+ except Exception as e:
35
+ import traceback
36
+ print(f"❌ Failed to launch: {e}")
37
+ print(f"Full traceback: {traceback.format_exc()}")
38
+ raise
39
+
40
+ def launch_ui(share: bool = True, debug: bool = False):
41
+ """Convenience function for Jupyter notebooks"""
42
+ return launch_final_blend_sceneweaver(share=share, debug=debug)
43
+
44
+ def main():
45
+ """Main entry point"""
46
+
47
+ # Check if running in Jupyter/Colab
48
+ try:
49
+ get_ipython()
50
+ is_jupyter = True
51
+ except NameError:
52
+ is_jupyter = False
53
+
54
+ if not is_jupyter and len(sys.argv) > 1 and not any('-f' in arg for arg in sys.argv):
55
+ # Command line mode with arguments
56
+ share = '--no-share' not in sys.argv
57
+ debug = '--debug' in sys.argv
58
+ else:
59
+ # Default mode
60
+ share = True
61
+ debug = False
62
+
63
+ try:
64
+ interface = launch_final_blend_sceneweaver(share=share, debug=debug)
65
+
66
+ if not is_jupyter:
67
+ print("🛑 Press Ctrl+C to stop")
68
+ try:
69
+ interface.block_thread()
70
+ except KeyboardInterrupt:
71
+ print("👋 Stopped")
72
+
73
+ return interface
74
+
75
+ except Exception as e:
76
+ print(f"❌ Error: {e}")
77
+ if not is_jupyter:
78
+ sys.exit(1)
79
+ raise
80
+
81
+ if __name__ == "__main__":
82
+ main()
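
For reference, a minimal notebook-style usage sketch of the entry points defined above (`launch_ui` and `launch_final_blend_sceneweaver` come from app.py as uploaded; the flag handling in `main()` only applies when the script is run from the command line):

```python
# Minimal sketch: launch SceneWeaver from a notebook or another script.
# launch_ui() is the convenience wrapper defined in app.py above.
from app import launch_ui

interface = launch_ui(share=True, debug=False)  # returns the launched Gradio interface
```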
css_styles.py ADDED
@@ -0,0 +1,513 @@
1
+ class CSSStyles:
2
+ """
3
+ CSS styling configuration for the SceneWeaver application.
4
+ Professional design system with clean typography and modern aesthetics.
5
+ """
6
+
7
+ @staticmethod
8
+ def get_main_css() -> str:
9
+ """
10
+ Get the main CSS styling for the application.
11
+
12
+ Returns:
13
+ Complete CSS string for Gradio interface styling
14
+ """
15
+ return """
16
+ /* Import professional fonts */
17
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
18
+
19
+ /* CSS Variables - Professional color system */
20
+ :root {
21
+ /* Primary brand colors */
22
+ --primary-color: #1e3a5f;
23
+ --primary-hover: #2d5a8a;
24
+ --primary-light: #e8f4fd;
25
+
26
+ /* Accent colors */
27
+ --accent-color: #3b82f6;
28
+ --accent-hover: #2563eb;
29
+ --accent-light: #dbeafe;
30
+
31
+ /* Status colors */
32
+ --success-color: #10b981;
33
+ --warning-color: #f59e0b;
34
+ --error-color: #ef4444;
35
+
36
+ /* Neutral colors */
37
+ --bg-primary: #ffffff;
38
+ --bg-secondary: #f8fafc;
39
+ --bg-tertiary: #f1f5f9;
40
+ --text-primary: #1e293b;
41
+ --text-secondary: #475569;
42
+ --text-muted: #94a3b8;
43
+ --border-color: #e2e8f0;
44
+ --border-light: #f1f5f9;
45
+
46
+ /* Shadows */
47
+ --shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
48
+ --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -2px rgba(0, 0, 0, 0.1);
49
+ --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -4px rgba(0, 0, 0, 0.1);
50
+ --shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 8px 10px -6px rgba(0, 0, 0, 0.1);
51
+
52
+ /* Border radius */
53
+ --radius-sm: 6px;
54
+ --radius-md: 8px;
55
+ --radius-lg: 12px;
56
+ --radius-xl: 16px;
57
+
58
+ /* Transitions */
59
+ --transition-fast: 150ms ease;
60
+ --transition-normal: 250ms ease;
61
+ --transition-slow: 350ms ease;
62
+ }
63
+
64
+ /* Global styles */
65
+ * {
66
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
67
+ -webkit-font-smoothing: antialiased !important;
68
+ -moz-osx-font-smoothing: grayscale !important;
69
+ }
70
+
71
+ /* Main container */
72
+ .gradio-container {
73
+ background: linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%) !important;
74
+ min-height: 100vh !important;
75
+ padding: 24px !important;
76
+ }
77
+
78
+ /* ===== HEADER SECTION ===== */
79
+ .main-header {
80
+ text-align: center !important;
81
+ padding: 48px 32px 40px !important;
82
+ margin-bottom: 32px !important;
83
+ background: linear-gradient(135deg, var(--bg-primary) 0%, var(--bg-secondary) 100%) !important;
84
+ border-radius: var(--radius-xl) !important;
85
+ box-shadow: var(--shadow-md) !important;
86
+ border: 1px solid var(--border-light) !important;
87
+ }
88
+
89
+ .main-title {
90
+ font-size: 2.75rem !important;
91
+ font-weight: 700 !important;
92
+ color: var(--primary-color) !important;
93
+ margin: 0 0 12px 0 !important;
94
+ letter-spacing: -0.03em !important;
95
+ display: flex !important;
96
+ align-items: center !important;
97
+ justify-content: center !important;
98
+ gap: 14px !important;
99
+ line-height: 1.2 !important;
100
+ }
101
+
102
+ .title-emoji {
103
+ font-size: 2.5rem !important;
104
+ filter: drop-shadow(0 2px 4px rgba(0,0,0,0.15)) !important;
105
+ transition: transform var(--transition-normal) !important;
106
+ }
107
+
108
+ .title-emoji:hover {
109
+ transform: scale(1.1) rotate(-5deg) !important;
110
+ }
111
+
112
+ .main-subtitle {
113
+ font-size: 1.1rem !important;
114
+ color: var(--text-secondary) !important;
115
+ font-weight: 400 !important;
116
+ margin: 0 !important;
117
+ line-height: 1.5 !important;
118
+ max-width: 700px !important;
119
+ margin: 0 auto !important;
120
+ }
121
+
122
+ /* ===== CARD SYSTEM ===== */
123
+ .feature-card {
124
+ background: var(--bg-primary) !important;
125
+ border: 1px solid var(--border-color) !important;
126
+ border-radius: var(--radius-lg) !important;
127
+ padding: 24px !important;
128
+ margin-bottom: 20px !important;
129
+ box-shadow: var(--shadow-sm) !important;
130
+ transition: all var(--transition-normal) !important;
131
+ position: relative !important;
132
+ }
133
+
134
+ .feature-card:hover {
135
+ border-color: var(--accent-color) !important;
136
+ box-shadow: var(--shadow-lg) !important;
137
+ transform: translateY(-2px) !important;
138
+ }
139
+
140
+ .card-title {
141
+ font-size: 1.25rem !important;
142
+ font-weight: 600 !important;
143
+ color: var(--text-primary) !important;
144
+ margin-bottom: 16px !important;
145
+ display: flex !important;
146
+ align-items: center !important;
147
+ gap: 10px !important;
148
+ }
149
+
150
+ .section-emoji {
151
+ font-size: 1.2rem !important;
152
+ filter: drop-shadow(0 1px 2px rgba(0,0,0,0.1)) !important;
153
+ }
154
+
155
+ /* ===== INPUT COMPONENTS ===== */
156
+ .input-field {
157
+ border: 1px solid var(--border-color) !important;
158
+ border-radius: var(--radius-md) !important;
159
+ background: var(--bg-primary) !important;
160
+ transition: all var(--transition-fast) !important;
161
+ }
162
+
163
+ .input-field:focus-within {
164
+ border-color: var(--accent-color) !important;
165
+ box-shadow: 0 0 0 3px var(--accent-light) !important;
166
+ }
167
+
168
+ /* ===== BUTTONS ===== */
169
+ .primary-button {
170
+ background: linear-gradient(135deg, var(--accent-color) 0%, var(--accent-hover) 100%) !important;
171
+ color: white !important;
172
+ border: none !important;
173
+ border-radius: var(--radius-md) !important;
174
+ padding: 14px 28px !important;
175
+ font-size: 1rem !important;
176
+ font-weight: 600 !important;
177
+ cursor: pointer !important;
178
+ transition: all var(--transition-normal) !important;
179
+ box-shadow: var(--shadow-md) !important;
180
+ }
181
+
182
+ .primary-button:hover {
183
+ transform: translateY(-2px) !important;
184
+ box-shadow: var(--shadow-lg) !important;
185
+ filter: brightness(1.05) !important;
186
+ }
187
+
188
+ .secondary-button {
189
+ background: var(--bg-primary) !important;
190
+ color: var(--accent-color) !important;
191
+ border: 1.5px solid var(--accent-color) !important;
192
+ border-radius: var(--radius-md) !important;
193
+ padding: 12px 20px !important;
194
+ font-size: 0.95rem !important;
195
+ font-weight: 500 !important;
196
+ cursor: pointer !important;
197
+ transition: all var(--transition-fast) !important;
198
+ }
199
+
200
+ .secondary-button:hover {
201
+ background: var(--accent-light) !important;
202
+ transform: translateY(-1px) !important;
203
+ }
204
+
205
+ /* ===== RESULTS GALLERY ===== */
206
+ #results-gallery-centered {
207
+ display: flex !important;
208
+ flex-direction: column !important;
209
+ align-items: center !important;
210
+ }
211
+
212
+ #results-gallery-centered .gradio-tabs {
213
+ width: 100% !important;
214
+ }
215
+
216
+ .result-gallery {
217
+ border-radius: var(--radius-lg) !important;
218
+ overflow: hidden !important;
219
+ border: 1px solid var(--border-color) !important;
220
+ box-shadow: var(--shadow-md) !important;
221
+ }
222
+
223
+ .result-gallery img {
224
+ width: 100% !important;
225
+ height: auto !important;
226
+ object-fit: contain !important;
227
+ }
228
+
229
+ /* ===== STATUS PANEL ===== */
230
+ .status-panel {
231
+ background: var(--bg-secondary) !important;
232
+ border: 1px solid var(--border-color) !important;
233
+ border-radius: var(--radius-md) !important;
234
+ padding: 12px 16px !important;
235
+ margin: 16px 0 !important;
236
+ }
237
+
238
+ .status-ready {
239
+ color: var(--success-color) !important;
240
+ font-weight: 500 !important;
241
+ }
242
+
243
+ /* ===== LOADING NOTICE ===== */
244
+ .loading-notice {
245
+ background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%) !important;
246
+ border: 1px solid #f59e0b !important;
247
+ border-radius: var(--radius-md) !important;
248
+ padding: 14px 18px !important;
249
+ margin: 16px 0 !important;
250
+ display: flex !important;
251
+ align-items: center !important;
252
+ gap: 12px !important;
253
+ }
254
+
255
+ .loading-notice-icon {
256
+ font-size: 1.3rem !important;
257
+ flex-shrink: 0 !important;
258
+ }
259
+
260
+ .loading-notice-text {
261
+ color: #92400e !important;
262
+ font-size: 0.9rem !important;
263
+ font-weight: 500 !important;
264
+ line-height: 1.5 !important;
265
+ }
266
+
267
+ /* ===== QUICK START GUIDE ===== */
268
+ .user-guidance-panel {
269
+ background: var(--bg-secondary) !important;
270
+ border: 1px solid var(--border-color) !important;
271
+ border-radius: var(--radius-md) !important;
272
+ margin: 16px 0 !important;
273
+ overflow: hidden !important;
274
+ }
275
+
276
+ .guidance-summary {
277
+ background: var(--bg-primary) !important;
278
+ padding: 12px 16px !important;
279
+ cursor: pointer !important;
280
+ font-weight: 500 !important;
281
+ color: var(--text-primary) !important;
282
+ transition: background var(--transition-fast) !important;
283
+ list-style: none !important;
284
+ display: flex !important;
285
+ align-items: center !important;
286
+ gap: 8px !important;
287
+ border-bottom: 1px solid var(--border-color) !important;
288
+ }
289
+
290
+ .guidance-summary:hover {
291
+ background: var(--accent-light) !important;
292
+ }
293
+
294
+ .guidance-summary::-webkit-details-marker {
295
+ display: none !important;
296
+ }
297
+
298
+ .guidance-content {
299
+ padding: 16px !important;
300
+ color: var(--text-secondary) !important;
301
+ line-height: 1.6 !important;
302
+ }
303
+
304
+ .guidance-content p {
305
+ margin: 8px 0 !important;
306
+ font-size: 0.9rem !important;
307
+ }
308
+
309
+ .guidance-content strong {
310
+ color: var(--primary-color) !important;
311
+ font-weight: 600 !important;
312
+ }
313
+
314
+ /* ===== FOOTER ===== */
315
+ .app-footer {
316
+ background: var(--bg-primary) !important;
317
+ border: 1px solid var(--border-color) !important;
318
+ border-radius: var(--radius-lg) !important;
319
+ padding: 32px !important;
320
+ margin-top: 32px !important;
321
+ text-align: center !important;
322
+ }
323
+
324
+ .footer-powered {
325
+ margin-bottom: 20px !important;
326
+ }
327
+
328
+ .footer-powered-title {
329
+ font-size: 0.85rem !important;
330
+ font-weight: 500 !important;
331
+ color: var(--text-muted) !important;
332
+ text-transform: uppercase !important;
333
+ letter-spacing: 0.1em !important;
334
+ margin-bottom: 16px !important;
335
+ }
336
+
337
+ .footer-tech-grid {
338
+ display: flex !important;
339
+ flex-wrap: wrap !important;
340
+ justify-content: center !important;
341
+ gap: 12px !important;
342
+ margin-bottom: 24px !important;
343
+ }
344
+
345
+ .footer-tech-item {
346
+ background: var(--bg-secondary) !important;
347
+ border: 1px solid var(--border-color) !important;
348
+ border-radius: var(--radius-sm) !important;
349
+ padding: 8px 16px !important;
350
+ font-size: 0.85rem !important;
351
+ font-weight: 500 !important;
352
+ color: var(--text-secondary) !important;
353
+ transition: all var(--transition-fast) !important;
354
+ }
355
+
356
+ .footer-tech-item:hover {
357
+ background: var(--accent-light) !important;
358
+ border-color: var(--accent-color) !important;
359
+ color: var(--accent-color) !important;
360
+ }
361
+
362
+ .footer-divider {
363
+ height: 1px !important;
364
+ background: var(--border-color) !important;
365
+ margin: 20px auto !important;
366
+ max-width: 400px !important;
367
+ }
368
+
369
+ .footer-copyright {
370
+ font-size: 0.85rem !important;
371
+ color: var(--text-muted) !important;
372
+ font-weight: 400 !important;
373
+ }
374
+
375
+ .footer-copyright a {
376
+ color: var(--accent-color) !important;
377
+ text-decoration: none !important;
378
+ font-weight: 500 !important;
379
+ }
380
+
381
+ .footer-copyright a:hover {
382
+ text-decoration: underline !important;
383
+ }
384
+
385
+ /* ===== TABS STYLING ===== */
386
+ .gradio-tabs {
387
+ border: none !important;
388
+ }
389
+
390
+ .gradio-tabs > .tab-nav {
391
+ background: var(--bg-secondary) !important;
392
+ border-radius: var(--radius-md) !important;
393
+ padding: 4px !important;
394
+ gap: 4px !important;
395
+ border: 1px solid var(--border-color) !important;
396
+ margin-bottom: 16px !important;
397
+ }
398
+
399
+ .gradio-tabs > .tab-nav > button {
400
+ border-radius: var(--radius-sm) !important;
401
+ padding: 10px 20px !important;
402
+ font-weight: 500 !important;
403
+ font-size: 0.9rem !important;
404
+ transition: all var(--transition-fast) !important;
405
+ border: none !important;
406
+ background: transparent !important;
407
+ color: var(--text-secondary) !important;
408
+ }
409
+
410
+ .gradio-tabs > .tab-nav > button.selected {
411
+ background: var(--bg-primary) !important;
412
+ color: var(--accent-color) !important;
413
+ box-shadow: var(--shadow-sm) !important;
414
+ }
415
+
416
+ .gradio-tabs > .tab-nav > button:hover:not(.selected) {
417
+ background: var(--bg-primary) !important;
418
+ color: var(--text-primary) !important;
419
+ }
420
+
421
+ /* ===== ACCORDION ===== */
422
+ .gradio-accordion {
423
+ border: 1px solid var(--border-color) !important;
424
+ border-radius: var(--radius-md) !important;
425
+ overflow: hidden !important;
426
+ margin-top: 12px !important;
427
+ }
428
+
429
+ .gradio-accordion > .label-wrap {
430
+ background: var(--bg-secondary) !important;
431
+ padding: 12px 16px !important;
432
+ font-weight: 500 !important;
433
+ }
434
+
435
+ .gradio-accordion > .label-wrap:hover {
436
+ background: var(--bg-tertiary) !important;
437
+ }
438
+
439
+ /* ===== RESPONSIVE ===== */
440
+ @media (max-width: 768px) {
441
+ .main-title {
442
+ font-size: 2rem !important;
443
+ }
444
+
445
+ .main-subtitle {
446
+ font-size: 1rem !important;
447
+ }
448
+
449
+ .footer-tech-grid {
450
+ gap: 8px !important;
451
+ }
452
+
453
+ .footer-tech-item {
454
+ padding: 6px 12px !important;
455
+ font-size: 0.8rem !important;
456
+ }
457
+ }
458
+
459
+ /* ===== EMOJI ENHANCEMENT ===== */
460
+ .emoji-enhanced {
461
+ display: inline-block !important;
462
+ font-style: normal !important;
463
+ filter: drop-shadow(0 1px 2px rgba(0,0,0,0.1)) !important;
464
+ transition: transform var(--transition-fast) !important;
465
+ }
466
+
467
+ .emoji-enhanced:hover {
468
+ transform: scale(1.1) !important;
469
+ }
470
+
471
+ /* ===== IMAGE DISPLAY FIX ===== */
472
+ .gradio-image {
473
+ min-height: 200px !important;
474
+ }
475
+
476
+ .gradio-image img {
477
+ max-height: 500px !important;
478
+ object-fit: contain !important;
479
+ }
480
+
481
+ /* ===== SCENE TEMPLATE DROPDOWN ===== */
482
+ .template-dropdown {
483
+ margin: 8px 0 !important;
484
+ }
485
+
486
+ .template-dropdown select,
487
+ .template-dropdown input {
488
+ font-size: 0.95rem !important;
489
+ padding: 10px 14px !important;
490
+ border-radius: var(--radius-md) !important;
491
+ border: 1px solid var(--border-color) !important;
492
+ background: var(--bg-primary) !important;
493
+ transition: all var(--transition-fast) !important;
494
+ }
495
+
496
+ .template-dropdown select:hover,
497
+ .template-dropdown input:hover {
498
+ border-color: var(--accent-color) !important;
499
+ }
500
+
501
+ .template-dropdown select:focus,
502
+ .template-dropdown input:focus {
503
+ border-color: var(--accent-color) !important;
504
+ box-shadow: 0 0 0 3px var(--accent-light) !important;
505
+ outline: none !important;
506
+ }
507
+
508
+ /* Dropdown option styling */
509
+ .template-dropdown option {
510
+ padding: 8px 12px !important;
511
+ font-size: 0.95rem !important;
512
+ }
513
+ """
image_blender.py ADDED
@@ -0,0 +1,802 @@
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ import logging
5
+ from typing import Dict, Any, Optional, Tuple
6
+
7
+ logger = logging.getLogger(__name__)
8
+ logger.setLevel(logging.INFO)
9
+
10
+ class ImageBlender:
11
+ """
12
+ Advanced image blending with aggressive spill suppression and color replacement
13
+ Completely eliminates yellow edge residue while maintaining sharp edges
14
+ """
15
+
16
+ EDGE_EROSION_PIXELS = 1 # Pixels to erode from mask edge (reduced to protect more foreground)
17
+ ALPHA_BINARIZE_THRESHOLD = 0.5 # Alpha threshold for binarization (increased to keep more foreground)
18
+ DARK_LUMINANCE_THRESHOLD = 60 # Luminance threshold for dark foreground detection
19
+ FOREGROUND_PROTECTION_THRESHOLD = 140 # Mask value above which pixels are strongly protected
20
+ BACKGROUND_COLOR_TOLERANCE = 30 # DeltaE tolerance for background color detection
21
+
22
+ def __init__(self, enable_multi_scale: bool = True):
23
+ """
24
+ Initialize ImageBlender.
25
+
26
+ Args:
27
+ enable_multi_scale: Whether to enable multi-scale edge refinement (default True)
28
+ """
29
+ self.enable_multi_scale = enable_multi_scale
30
+ self._debug_info = {}
31
+ self._adaptive_strength_map = None
32
+
33
+ def _erode_mask_edges(
34
+ self,
35
+ mask_array: np.ndarray,
36
+ erosion_pixels: int = 2
37
+ ) -> np.ndarray:
38
+ """
39
+ Erode mask edges to remove contaminated boundary pixels.
40
+
41
+ This removes the outermost pixels of the foreground mask where
42
+ color contamination from the original background is most likely.
43
+
44
+ Args:
45
+ mask_array: Input mask as numpy array (uint8, 0-255)
46
+ erosion_pixels: Number of pixels to erode (default 2)
47
+
48
+ Returns:
49
+ Eroded mask array (uint8)
50
+ """
51
+ if erosion_pixels <= 0:
52
+ return mask_array
53
+
54
+ # Use elliptical kernel for natural-looking erosion
55
+ kernel_size = max(2, erosion_pixels)
56
+ kernel = cv2.getStructuringElement(
57
+ cv2.MORPH_ELLIPSE,
58
+ (kernel_size, kernel_size)
59
+ )
60
+
61
+ # Apply erosion
62
+ eroded = cv2.erode(mask_array, kernel, iterations=1)
63
+
64
+ # Slight blur to smooth the eroded edges
65
+ eroded = cv2.GaussianBlur(eroded, (3, 3), 0)
66
+
67
+ logger.debug(f"Mask erosion applied: {erosion_pixels}px, kernel size: {kernel_size}")
68
+ return eroded
69
+
70
+ def _binarize_edge_alpha(
71
+ self,
72
+ alpha: np.ndarray,
73
+ mask_array: np.ndarray,
74
+ orig_array: np.ndarray,
75
+ threshold: float = 0.45
76
+ ) -> np.ndarray:
77
+ """
78
+ Binarize semi-transparent edge pixels to eliminate color bleeding.
79
+
80
+ Semi-transparent pixels at edges cause visible contamination because
81
+ they blend the original (potentially dark) foreground with the new
82
+ background. This method forces edge pixels to be either fully opaque
83
+ or fully transparent.
84
+
85
+ Args:
86
+ alpha: Current alpha channel (float32, 0.0-1.0)
87
+ mask_array: Original mask array (uint8, 0-255)
88
+ orig_array: Original foreground image array (uint8, RGB)
89
+ threshold: Alpha threshold for binarization decision (default 0.45)
90
+
91
+ Returns:
92
+ Modified alpha array with binarized edges (float32)
93
+ """
94
+ # Identify semi-transparent edge zone (not fully opaque, not fully transparent)
95
+ edge_zone = (alpha > 0.05) & (alpha < 0.95)
96
+
97
+ if not np.any(edge_zone):
98
+ return alpha
99
+
100
+ # Calculate local foreground luminance for adaptive thresholding
101
+ gray = np.mean(orig_array, axis=2)
102
+
103
+ # For dark foreground pixels, use slightly higher threshold
104
+ # to preserve more of the dark subject
105
+ is_dark = gray < self.DARK_LUMINANCE_THRESHOLD
106
+
107
+ # Create adaptive threshold map
108
+ adaptive_threshold = np.full_like(alpha, threshold)
109
+ adaptive_threshold[is_dark] = threshold + 0.1 # Keep more dark pixels
110
+
111
+ # Binarize: above threshold -> opaque, below -> transparent
112
+ alpha_binarized = alpha.copy()
113
+
114
+ # Pixels above threshold become fully opaque
115
+ make_opaque = edge_zone & (alpha > adaptive_threshold)
116
+ alpha_binarized[make_opaque] = 1.0
117
+
118
+ # Pixels below threshold become fully transparent
119
+ make_transparent = edge_zone & (alpha <= adaptive_threshold)
120
+ alpha_binarized[make_transparent] = 0.0
121
+
122
+ # Log statistics
123
+ num_opaque = np.sum(make_opaque)
124
+ num_transparent = np.sum(make_transparent)
125
+ logger.info(f"Edge binarization: {num_opaque} pixels -> opaque, {num_transparent} pixels -> transparent")
126
+
127
+ return alpha_binarized
128
+
129
+ def _apply_edge_cleanup(
130
+ self,
131
+ result_array: np.ndarray,
132
+ bg_array: np.ndarray,
133
+ alpha: np.ndarray,
134
+ cleanup_width: int = 2
135
+ ) -> np.ndarray:
136
+ """
137
+ Final cleanup pass to remove any remaining edge artifacts.
138
+
139
+ Detects remaining semi-transparent edges and replaces them with
140
+ either pure foreground or pure background colors.
141
+
142
+ Args:
143
+ result_array: Current blended result (uint8, RGB)
144
+ bg_array: Background image array (uint8, RGB)
145
+ alpha: Final alpha channel (float32, 0.0-1.0)
146
+ cleanup_width: Width of edge zone to clean (default 2)
147
+
148
+ Returns:
149
+ Cleaned result array (uint8)
150
+ """
151
+ # Find edge pixels that might still have artifacts
152
+ # These are pixels with alpha close to but not exactly 0 or 1
153
+ residual_edge = (alpha > 0.01) & (alpha < 0.99) & (alpha != 0.0) & (alpha != 1.0)
154
+
155
+ if not np.any(residual_edge):
156
+ return result_array
157
+
158
+ result_cleaned = result_array.copy()
159
+
160
+ # For residual edge pixels, snap to nearest pure state
161
+ snap_to_bg = residual_edge & (alpha < 0.5)
162
+ snap_to_fg = residual_edge & (alpha >= 0.5)
163
+
164
+ # Replace with background
165
+ result_cleaned[snap_to_bg] = bg_array[snap_to_bg]
166
+
167
+ # For foreground, keep original but ensure no blending artifacts
168
+ # (already handled by the blend, so no action needed for snap_to_fg)
169
+
170
+ num_cleaned = np.sum(residual_edge)
171
+ if num_cleaned > 0:
172
+ logger.debug(f"Edge cleanup: {num_cleaned} residual pixels cleaned")
173
+
174
+ return result_cleaned
175
+
176
+ def _remove_background_color_contamination(
177
+ self,
178
+ image_array: np.ndarray,
179
+ mask_array: np.ndarray,
180
+ orig_bg_color_lab: np.ndarray,
181
+ tolerance: float = 30.0
182
+ ) -> np.ndarray:
183
+ """
184
+ Remove original background color contamination from foreground pixels.
185
+
186
+ Scans the foreground area for pixels that match the original background
187
+ color and replaces them with nearby clean foreground colors.
188
+
189
+ Args:
190
+ image_array: Foreground image array (uint8, RGB)
191
+ mask_array: Mask array (uint8, 0-255)
192
+ orig_bg_color_lab: Original background color in Lab space
193
+ tolerance: DeltaE tolerance for detecting contaminated pixels
194
+
195
+ Returns:
196
+ Cleaned image array (uint8)
197
+ """
198
+ # Convert to Lab for color comparison
199
+ image_lab = cv2.cvtColor(image_array, cv2.COLOR_RGB2LAB).astype(np.float32)
200
+
201
+ # Only process foreground pixels (mask > 50)
202
+ foreground_mask = mask_array > 50
203
+
204
+ if not np.any(foreground_mask):
205
+ return image_array
206
+
207
+ # Calculate deltaE from original background color for all pixels
208
+ delta_l = image_lab[:, :, 0] - orig_bg_color_lab[0]
209
+ delta_a = image_lab[:, :, 1] - orig_bg_color_lab[1]
210
+ delta_b = image_lab[:, :, 2] - orig_bg_color_lab[2]
211
+ delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
212
+
213
+ # Find contaminated pixels: in foreground but color similar to original background
214
+ contaminated = foreground_mask & (delta_e < tolerance)
215
+
216
+ if not np.any(contaminated):
217
+ logger.debug("No background color contamination detected in foreground")
218
+ return image_array
219
+
220
+ num_contaminated = np.sum(contaminated)
221
+ logger.info(f"Found {num_contaminated} pixels with background color contamination")
222
+
223
+ # Create output array
224
+ result = image_array.copy()
225
+
226
+ # For contaminated pixels, use inpainting to replace with surrounding colors
227
+ inpaint_mask = contaminated.astype(np.uint8) * 255
228
+
229
+ try:
230
+ # Use inpainting to fill contaminated areas with surrounding foreground colors
231
+ result = cv2.inpaint(result, inpaint_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
232
+ logger.info(f"Inpainted {num_contaminated} contaminated pixels")
233
+ except Exception as e:
234
+ logger.warning(f"Inpainting failed: {e}, using median filter fallback")
235
+ # Fallback: apply median filter to contaminated areas
236
+ median_filtered = cv2.medianBlur(image_array, 5)
237
+ result[contaminated] = median_filtered[contaminated]
238
+
239
+ return result
240
+
241
+ def _protect_foreground_core(
242
+ self,
243
+ result_array: np.ndarray,
244
+ orig_array: np.ndarray,
245
+ mask_array: np.ndarray,
246
+ protection_threshold: int = 140
247
+ ) -> np.ndarray:
248
+ """
249
+ Strongly protect core foreground pixels from any background influence.
250
+
251
+ For pixels with high mask confidence, directly use the original foreground
252
+ color without any blending, ensuring faces and bodies are not affected.
253
+
254
+ Args:
255
+ result_array: Current blended result (uint8, RGB)
256
+ orig_array: Original foreground image (uint8, RGB)
257
+ mask_array: Mask array (uint8, 0-255)
258
+ protection_threshold: Mask value above which pixels are fully protected
259
+
260
+ Returns:
261
+ Protected result array (uint8)
262
+ """
263
+ # Identify strongly protected foreground pixels
264
+ strong_foreground = mask_array >= protection_threshold
265
+
266
+ if not np.any(strong_foreground):
267
+ return result_array
268
+
269
+ # For these pixels, use original foreground color directly
270
+ result_protected = result_array.copy()
271
+ result_protected[strong_foreground] = orig_array[strong_foreground]
272
+
273
+ num_protected = np.sum(strong_foreground)
274
+ logger.info(f"Protected {num_protected} core foreground pixels from background influence")
275
+
276
+ return result_protected
277
+
278
+ def multi_scale_edge_refinement(
279
+ self,
280
+ original_image: Image.Image,
281
+ background_image: Image.Image,
282
+ mask: Image.Image
283
+ ) -> Image.Image:
284
+ """
285
+ Multi-scale edge refinement for better edge quality.
286
+ Uses image pyramid to handle edges at different scales.
287
+
288
+ Args:
289
+ original_image: Foreground PIL Image
290
+ background_image: Background PIL Image
291
+ mask: Current mask PIL Image
292
+
293
+ Returns:
294
+ Refined mask PIL Image
295
+ """
296
+ logger.info("🔍 Starting multi-scale edge refinement...")
297
+
298
+ try:
299
+ # Convert to numpy arrays
300
+ orig_array = np.array(original_image.convert('RGB'))
301
+ mask_array = np.array(mask).astype(np.float32)
302
+ height, width = mask_array.shape
303
+
304
+ # Define scales for pyramid
305
+ scales = [1.0, 0.5, 0.25] # Original, half, quarter
306
+ scale_masks = []
307
+ scale_complexities = []
308
+
309
+ # Convert to grayscale for edge detection
310
+ gray = cv2.cvtColor(orig_array, cv2.COLOR_RGB2GRAY)
311
+
312
+ for scale in scales:
313
+ if scale == 1.0:
314
+ scaled_gray = gray
315
+ scaled_mask = mask_array
316
+ else:
317
+ new_h = int(height * scale)
318
+ new_w = int(width * scale)
319
+ scaled_gray = cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
320
+ scaled_mask = cv2.resize(mask_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
321
+
322
+ # Compute local complexity using gradient standard deviation
323
+ sobel_x = cv2.Sobel(scaled_gray, cv2.CV_64F, 1, 0, ksize=3)
324
+ sobel_y = cv2.Sobel(scaled_gray, cv2.CV_64F, 0, 1, ksize=3)
325
+ gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
326
+
327
+ # Calculate local complexity in 5x5 regions
328
+ kernel_size = 5
329
+ complexity = cv2.blur(gradient_mag, (kernel_size, kernel_size))
330
+
331
+ # Resize back to original size
332
+ if scale != 1.0:
333
+ scaled_mask = cv2.resize(scaled_mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
334
+ complexity = cv2.resize(complexity, (width, height), interpolation=cv2.INTER_LANCZOS4)
335
+
336
+ scale_masks.append(scaled_mask)
337
+ scale_complexities.append(complexity)
338
+
339
+ # Compute weights based on complexity
340
+ # High complexity -> use high resolution mask
341
+ # Low complexity -> use low resolution mask (smoother)
342
+ weights = np.zeros((len(scales), height, width), dtype=np.float32)
343
+
344
+ # Normalize complexities
345
+ max_complexity = max(c.max() for c in scale_complexities) + 1e-6
346
+ normalized_complexities = [c / max_complexity for c in scale_complexities]
347
+
348
+ # Weight assignment: higher complexity at each scale means that scale is more reliable
349
+ for i, complexity in enumerate(normalized_complexities):
350
+ if i == 0: # High resolution - prefer for high complexity regions
351
+ weights[i] = complexity
352
+ elif i == 1: # Medium resolution - moderate complexity
353
+ weights[i] = 0.5 * (1 - complexity) + 0.5 * complexity * 0.5
354
+ else: # Low resolution - prefer for low complexity regions
355
+ weights[i] = 1 - complexity
356
+
357
+ # Normalize weights so they sum to 1 at each pixel
358
+ weight_sum = weights.sum(axis=0, keepdims=True) + 1e-6
359
+ weights = weights / weight_sum
360
+
361
+ # Weighted blend of masks from different scales
362
+ refined_mask = np.zeros((height, width), dtype=np.float32)
363
+ for i, mask_i in enumerate(scale_masks):
364
+ refined_mask += weights[i] * mask_i
365
+
366
+ # Clip and convert to uint8
367
+ refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)
368
+
369
+ logger.info("✅ Multi-scale edge refinement completed")
370
+ return Image.fromarray(refined_mask, mode='L')
371
+
372
+ except Exception as e:
373
+ logger.error(f"❌ Multi-scale refinement failed: {e}, using original mask")
374
+ return mask
375
+
376
+ def simple_blend_images(
377
+ self,
378
+ original_image: Image.Image,
379
+ background_image: Image.Image,
380
+ combination_mask: Image.Image,
381
+ use_multi_scale: Optional[bool] = None
382
+ ) -> Image.Image:
383
+ """
384
+ Aggressive spill suppression and color replacement: eliminates yellow edge residue while maintaining sharp edges.
385
+
386
+ Args:
387
+ original_image: Foreground PIL Image
388
+ background_image: Background PIL Image
389
+ combination_mask: Mask PIL Image (L mode)
390
+ use_multi_scale: Override for multi-scale refinement (None = use class default)
391
+
392
+ Returns:
393
+ Blended PIL Image
394
+ """
395
+ logger.info("🎨 Starting advanced image blending process...")
396
+
397
+ # Apply multi-scale edge refinement if enabled
398
+ should_use_multi_scale = use_multi_scale if use_multi_scale is not None else self.enable_multi_scale
399
+ if should_use_multi_scale:
400
+ combination_mask = self.multi_scale_edge_refinement(
401
+ original_image, background_image, combination_mask
402
+ )
403
+
404
+ # Convert to numpy arrays
405
+ orig_array = np.array(original_image, dtype=np.uint8)
406
+ bg_array = np.array(background_image, dtype=np.uint8)
407
+ mask_array = np.array(combination_mask, dtype=np.uint8)
408
+
409
+ logger.info(f"📊 Image dimensions - Original: {orig_array.shape}, Background: {bg_array.shape}, Mask: {mask_array.shape}")
410
+ logger.info(f"📊 Mask statistics (before erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
411
+
412
+ # === NEW: Apply mask erosion to remove contaminated edge pixels ===
413
+ mask_array = self._erode_mask_edges(mask_array, self.EDGE_EROSION_PIXELS)
414
+ logger.info(f"📊 Mask statistics (after erosion) - Mean: {mask_array.mean():.1f}, Min: {mask_array.min()}, Max: {mask_array.max()}")
415
+
416
+ # Enhanced parameters for better spill suppression
417
+ RING_WIDTH_PX = 4 # Increased ring width for better coverage
418
+ SPILL_STRENGTH = 0.85 # Stronger spill suppression
419
+ L_MATCH_STRENGTH = 0.65 # Stronger luminance matching
420
+ DELTAE_THRESHOLD = 18 # More aggressive contamination detection
421
+ HARD_EDGE_PROTECT = True # Black edge protection
422
+ INPAINT_FALLBACK = True # inpaint fallback repair
423
+ MULTI_PASS_CORRECTION = True # Enable multi-pass correction
424
+
425
+ # === Estimate original background color and foreground representative color ===
426
+ height, width = orig_array.shape[:2]
427
+
428
+ # Take 15px from each side to estimate original background color
429
+ edge_width = 15
430
+ border_pixels = []
431
+
432
+ # Collect border pixels (excluding foreground areas)
433
+ border_mask = np.zeros((height, width), dtype=bool)
434
+ border_mask[:edge_width, :] = True # Top edge
435
+ border_mask[-edge_width:, :] = True # Bottom edge
436
+ border_mask[:, :edge_width] = True # Left edge
437
+ border_mask[:, -edge_width:] = True # Right edge
438
+
439
+ # Exclude foreground areas
440
+ fg_binary = mask_array > 50
441
+ border_mask = border_mask & (~fg_binary)
442
+
443
+ if np.any(border_mask):
444
+ border_pixels = orig_array[border_mask].reshape(-1, 3)
445
+
446
+ # Simplified background color estimation (no sklearn dependency)
447
+ try:
448
+ if len(border_pixels) > 100:
449
+ # Use histogram to find mode colors
450
+ # Quantize RGB to coarser grid to find main colors
451
+ quantized = (border_pixels // 32) * 32 # 8-level quantization
452
+
453
+ # Find most frequent color
454
+ unique_colors, counts = np.unique(quantized.reshape(-1, quantized.shape[-1]),
455
+ axis=0, return_counts=True)
456
+ most_common_idx = np.argmax(counts)
457
+ orig_bg_color_rgb = unique_colors[most_common_idx].astype(np.uint8)
458
+ else:
459
+ orig_bg_color_rgb = np.median(border_pixels, axis=0).astype(np.uint8)
460
+ except Exception:
461
+ # Fallback: use four corners average
462
+ corners = np.array([orig_array[0,0], orig_array[0,-1],
463
+ orig_array[-1,0], orig_array[-1,-1]])
464
+ orig_bg_color_rgb = np.mean(corners, axis=0).astype(np.uint8)
465
+ else:
466
+ orig_bg_color_rgb = np.array([200, 180, 120], dtype=np.uint8) # Default yellow
467
+
468
+ # Convert to Lab space
469
+ orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
470
+ logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")
471
+
472
+ # Remove original background color contamination from foreground
473
+ orig_array = self._remove_background_color_contamination(
474
+ orig_array,
475
+ mask_array,
476
+ orig_bg_color_lab,
477
+ tolerance=self.BACKGROUND_COLOR_TOLERANCE
478
+ )
479
+
480
+ # Redefine trimap, optimized for cartoon characters
481
+ try:
482
+ kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
483
+
484
+ # FG_CORE: Reduce erosion iterations from 2 to 1 to avoid losing thin limbs
485
+ mask_eroded_once = cv2.erode(mask_array, kernel_3x3, iterations=1)
486
+ fg_core = mask_eroded_once > 127 # Adjustable parameter: erosion iterations
487
+
488
+ # RING: Use morphological gradient to redefine, ensuring only thin edge band
489
+ mask_dilated = cv2.dilate(mask_array, kernel_3x3, iterations=1)
490
+ mask_eroded = cv2.erode(mask_array, kernel_3x3, iterations=1)
491
+
492
+ # Ensure consistent data types to avoid overflow
493
+ morphological_gradient = cv2.subtract(mask_dilated, mask_eroded)
494
+ ring_zone = morphological_gradient > 0 # Areas with morphological gradient > 0 are edge bands
495
+
496
+ # BG: background area
497
+ bg_zone = mask_array < 30
498
+
499
+ logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")
500
+
501
+ except Exception as e:
502
+ import traceback
503
+ logger.error(f"❌ Trimap definition failed: {e}")
504
+ logger.error(f"📍 Traceback: {traceback.format_exc()}")
505
+ print(f"❌ TRIMAP ERROR: {e}")
506
+ print(f"Traceback: {traceback.format_exc()}")
507
+ # Fallback to simple definition
508
+ fg_core = mask_array > 200
509
+ ring_zone = (mask_array > 50) & (mask_array <= 200)
510
+ bg_zone = mask_array <= 50
511
+
512
+ # Foreground representative color: estimated from FG_CORE
513
+ if np.any(fg_core):
514
+ fg_pixels = orig_array[fg_core].reshape(-1, 3)
515
+ fg_rep_color_rgb = np.median(fg_pixels, axis=0).astype(np.uint8)
516
+ else:
517
+ fg_rep_color_rgb = np.array([80, 60, 40], dtype=np.uint8) # Default dark
518
+
519
+ fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
520
+
521
+ # Edge band spill suppression and repair
522
+ if np.any(ring_zone):
523
+ # Convert to Lab space
524
+ orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
525
+ orig_array_working = orig_array.copy().astype(np.float32)
526
+
527
+ # ΔE detect contaminated pixels
528
+ ring_pixels_lab = orig_lab[ring_zone]
529
+
530
+ # Calculate ΔE with original background color (simplified version)
531
+ delta_l = ring_pixels_lab[:, 0] - orig_bg_color_lab[0]
532
+ delta_a = ring_pixels_lab[:, 1] - orig_bg_color_lab[1]
533
+ delta_b = ring_pixels_lab[:, 2] - orig_bg_color_lab[2]
534
+ delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
535
+
536
+ # Contaminated pixel mask
537
+ contaminated_mask = delta_e < DELTAE_THRESHOLD
538
+
539
+ if np.any(contaminated_mask):
540
+ # Calculate adaptive strength based on delta_e for each pixel
541
+ # Pixels closer to background color get stronger correction
542
+ contaminated_delta_e = delta_e[contaminated_mask]
543
+
544
+ # Adaptive strength formula: inverse relationship with delta_e
545
+ # Pixels very close to bg color (low delta_e) -> strong correction
546
+ # Pixels further from bg color (high delta_e) -> lighter correction
547
+ adaptive_strength = SPILL_STRENGTH * np.maximum(
548
+ 0.0,
549
+ 1.0 - (contaminated_delta_e / DELTAE_THRESHOLD)
550
+ )
551
+
552
+ # Clamp adaptive strength to reasonable range (30% - 100% of base strength)
553
+ min_strength = SPILL_STRENGTH * 0.3
554
+ adaptive_strength = np.clip(adaptive_strength, min_strength, SPILL_STRENGTH)
555
+
556
+ # Store for debug visualization
557
+ self._adaptive_strength_map = np.zeros_like(delta_e)
558
+ self._adaptive_strength_map[contaminated_mask] = adaptive_strength
559
+
560
+ logger.info(f"📊 Adaptive strength stats - Mean: {adaptive_strength.mean():.3f}, Min: {adaptive_strength.min():.3f}, Max: {adaptive_strength.max():.3f}")
561
+
562
+ # Chroma vector deprojection
563
+ bg_chroma = np.array([orig_bg_color_lab[1], orig_bg_color_lab[2]])
564
+ bg_chroma_norm = bg_chroma / (np.linalg.norm(bg_chroma) + 1e-6)
565
+
566
+ # Color correction for contaminated pixels
567
+ contaminated_pixels = ring_pixels_lab[contaminated_mask]
568
+
569
+ # Remove background chroma component with adaptive strength (per-pixel)
570
+ pixel_chroma = contaminated_pixels[:, 1:3] # a, b channels
571
+ projection = np.dot(pixel_chroma, bg_chroma_norm)[:, np.newaxis] * bg_chroma_norm
572
+
573
+ # Apply adaptive strength per pixel
574
+ adaptive_strength_2d = adaptive_strength[:, np.newaxis]
575
+ corrected_chroma = pixel_chroma - projection * adaptive_strength_2d
576
+
577
+ # Converge toward foreground representative color with adaptive strength
578
+ convergence_factor = adaptive_strength_2d * 0.6
579
+ corrected_chroma = (corrected_chroma * (1 - convergence_factor) +
580
+ fg_rep_color_lab[1:3] * convergence_factor)
581
+
582
+ # Adaptive luminance matching
583
+ adaptive_l_strength = adaptive_strength * (L_MATCH_STRENGTH / SPILL_STRENGTH)
584
+ corrected_l = (contaminated_pixels[:, 0] * (1 - adaptive_l_strength) +
585
+ fg_rep_color_lab[0] * adaptive_l_strength)
586
+
587
+ # Update Lab values
588
+ ring_pixels_lab[contaminated_mask, 0] = corrected_l
589
+ ring_pixels_lab[contaminated_mask, 1:3] = corrected_chroma
590
+
591
+ # Write back to original image
592
+ orig_lab[ring_zone] = ring_pixels_lab
593
+
594
+ # Dark edge protection
595
+ if HARD_EDGE_PROTECT:
596
+ gray = np.mean(orig_array, axis=2)
597
+ # Detect dark and high gradient areas
598
+ sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
599
+ sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
600
+ gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
601
+
602
+ dark_edge_zone = ring_zone & (gray < 60) & (gradient_mag > 20)
603
+ # Protect these areas from excessive modification, copy directly from original
604
+ if np.any(dark_edge_zone):
605
+ orig_lab[dark_edge_zone] = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB)[dark_edge_zone]
606
+
607
+ # Multi-pass correction for stubborn spill
608
+ if MULTI_PASS_CORRECTION:
609
+ # Second pass for remaining contamination
610
+ ring_pixels_lab_pass2 = orig_lab[ring_zone]
611
+ delta_l_pass2 = ring_pixels_lab_pass2[:, 0] - orig_bg_color_lab[0]
612
+ delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
613
+ delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
614
+ delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)
615
+
616
+ still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
617
+
618
+ if np.any(still_contaminated):
619
+ # Apply stronger correction to remaining contaminated pixels
620
+ remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
621
+
622
+ # More aggressive chroma neutralization
623
+ remaining_chroma = remaining_pixels[:, 1:3]
624
+ neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7
625
+
626
+ # Stronger luminance matching
627
+ neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
628
+
629
+ ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
630
+ ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
631
+ orig_lab[ring_zone] = ring_pixels_lab_pass2
632
+
633
+ # Convert back to RGB
634
+ orig_lab_clipped = np.clip(orig_lab, 0, 255).astype(np.uint8)
635
+ orig_array_corrected = cv2.cvtColor(orig_lab_clipped, cv2.COLOR_LAB2RGB)
636
+
637
+ # inpaint fallback repair
638
+ if INPAINT_FALLBACK:
639
+ # inpaint still contaminated outermost pixels
640
+ final_contaminated = ring_zone.copy()
641
+
642
+ # Check if there's still contamination after repair
643
+ final_lab = cv2.cvtColor(orig_array_corrected, cv2.COLOR_RGB2LAB).astype(np.float32)
644
+ final_ring_lab = final_lab[ring_zone]
645
+ final_delta_l = final_ring_lab[:, 0] - orig_bg_color_lab[0]
646
+ final_delta_a = final_ring_lab[:, 1] - orig_bg_color_lab[1]
647
+ final_delta_b = final_ring_lab[:, 2] - orig_bg_color_lab[2]
648
+ final_delta_e = np.sqrt(final_delta_l**2 + final_delta_a**2 + final_delta_b**2)
649
+
650
+ still_contaminated = final_delta_e < (DELTAE_THRESHOLD * 0.5)
651
+ if np.any(still_contaminated):
652
+ # Create inpaint mask
653
+ inpaint_mask = np.zeros((height, width), dtype=np.uint8)
654
+ ring_coords = np.where(ring_zone)
655
+ inpaint_coords = (ring_coords[0][still_contaminated], ring_coords[1][still_contaminated])
656
+ inpaint_mask[inpaint_coords] = 255
657
+
658
+ # Execute inpaint
659
+ try:
660
+ orig_array_corrected = cv2.inpaint(orig_array_corrected, inpaint_mask, 3, cv2.INPAINT_TELEA)
661
+ except Exception:
662
+ # Fallback: directly cover with foreground representative color
663
+ orig_array_corrected[inpaint_coords] = fg_rep_color_rgb
664
+
665
+ orig_array = orig_array_corrected
666
+
667
+ # === Linear space blending (keep original logic) ===
668
+ def srgb_to_linear(img):
669
+ img_f = img.astype(np.float32) / 255.0
670
+ return np.where(img_f <= 0.04045, img_f / 12.92, np.power((img_f + 0.055) / 1.055, 2.4))
671
+
672
+ def linear_to_srgb(img):
673
+ img_clipped = np.clip(img, 0, 1)
674
+ return np.where(img_clipped <= 0.0031308,
675
+ 12.92 * img_clipped,
676
+ 1.055 * np.power(img_clipped, 1/2.4) - 0.055)
677
+
678
+ orig_linear = srgb_to_linear(orig_array)
679
+ bg_linear = srgb_to_linear(bg_array)
680
+
681
+ # === Cartoon-optimized Alpha calculation ===
682
+ alpha = mask_array.astype(np.float32) / 255.0
683
+
684
+ # Core foreground region - fully opaque
685
+ alpha[fg_core] = 1.0
686
+
687
+ # Background region - fully transparent
688
+ alpha[bg_zone] = 0.0
689
+
690
+ # Key fix: force pixels with mask >= 160 to alpha = 1.0 so bright fill areas are not capped at 0.9
691
+ high_confidence_pixels = mask_array >= 160
692
+ alpha[high_confidence_pixels] = 1.0
693
+ logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
694
+
695
+ # Ring area can be dehaloed, but doesn't affect already set high confidence pixels
696
+ ring_without_high_conf = ring_zone & (~high_confidence_pixels)
697
+ alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
698
+
699
+ # Retain existing black outline/strong edge protection
700
+ orig_gray = np.mean(orig_array, axis=2)
701
+
702
+ # Detect strong edge areas
703
+ sobel_x = cv2.Sobel(orig_gray, cv2.CV_64F, 1, 0, ksize=3)
704
+ sobel_y = cv2.Sobel(orig_gray, cv2.CV_64F, 0, 1, ksize=3)
705
+ gradient_mag = np.sqrt(sobel_x**2 + sobel_y**2)
706
+
707
+ # Black outline/strong edge protection: nearly fully opaque
708
+ black_edge_threshold = 60 # black edge threshold
709
+ gradient_threshold = 25 # gradient threshold
710
+ strong_edges = (orig_gray < black_edge_threshold) & (gradient_mag > gradient_threshold) & (mask_array > 10)
711
+ alpha[strong_edges] = np.maximum(alpha[strong_edges], 0.995) # black edge alpha
712
+
713
+ logger.info(f"🛡️ Protection applied - High conf: {high_confidence_pixels.sum()}, Strong edges: {strong_edges.sum()}")
714
+
715
+ # Apply edge alpha binarization to eliminate semi-transparent artifacts
716
+ alpha = self._binarize_edge_alpha(
717
+ alpha,
718
+ mask_array,
719
+ orig_array,
720
+ threshold=self.ALPHA_BINARIZE_THRESHOLD
721
+ )
722
+
723
+ # Final blending
724
+ alpha_3d = alpha[:, :, np.newaxis]
725
+ result_linear = orig_linear * alpha_3d + bg_linear * (1 - alpha_3d)
726
+ result_srgb = linear_to_srgb(result_linear)
727
+ result_array = (result_srgb * 255).astype(np.uint8)
728
+
729
+ # Final edge cleanup pass
730
+ result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
731
+
732
+ # Protect core foreground from any background influence
733
+ # This ensures faces and bodies retain original colors
734
+ result_array = self._protect_foreground_core(
735
+ result_array,
736
+ np.array(original_image, dtype=np.uint8), # Use original unprocessed image
737
+ mask_array,
738
+ protection_threshold=self.FOREGROUND_PROTECTION_THRESHOLD
739
+ )
740
+
741
+ # Store debug information (for debug output)
742
+ self._debug_info = {
743
+ 'orig_bg_color_rgb': orig_bg_color_rgb,
744
+ 'fg_rep_color_rgb': fg_rep_color_rgb,
745
+ 'orig_bg_color_lab': orig_bg_color_lab,
746
+ 'fg_rep_color_lab': fg_rep_color_lab,
747
+ 'ring_zone': ring_zone,
748
+ 'fg_core': fg_core,
749
+ 'alpha_final': alpha
750
+ }
751
+
752
+ return Image.fromarray(result_array)
753
+
754
+ def create_debug_images(
755
+ self,
756
+ original_image: Image.Image,
757
+ generated_background: Image.Image,
758
+ combination_mask: Image.Image,
759
+ combined_image: Image.Image
760
+ ) -> Dict[str, Image.Image]:
761
+ """
762
+ Generate debug images: (a) Final mask grayscale (b) Alpha heatmap (c) Ring visualization overlay
763
+ """
764
+ debug_images = {}
765
+
766
+ # Final Mask grayscale
767
+ debug_images["mask_gray"] = combination_mask.convert('L')
768
+
769
+ # Alpha Heatmap
770
+ mask_array = np.array(combination_mask.convert('L'))
771
+ heatmap_colored = cv2.applyColorMap(mask_array, cv2.COLORMAP_JET)
772
+ heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
773
+ debug_images["alpha_heatmap"] = Image.fromarray(heatmap_rgb)
774
+
775
+ # Ring visualization overlay - show ring areas on original image
776
+ if hasattr(self, '_debug_info') and 'ring_zone' in self._debug_info:
777
+ ring_zone = self._debug_info['ring_zone']
778
+ orig_array = np.array(original_image)
779
+ ring_overlay = orig_array.copy()
780
+
781
+ # Mark ring areas with red semi-transparent overlay
782
+ ring_overlay[ring_zone] = ring_overlay[ring_zone] * 0.7 + np.array([255, 0, 0]) * 0.3
783
+ debug_images["ring_visualization"] = Image.fromarray(ring_overlay.astype(np.uint8))
784
+ else:
785
+ # If no ring information, use original image
786
+ debug_images["ring_visualization"] = original_image
787
+
788
+ # Adaptive strength heatmap - visualize per-pixel correction strength
789
+ if hasattr(self, '_adaptive_strength_map') and self._adaptive_strength_map is not None:
790
+ # Normalize adaptive strength to 0-255 for visualization
791
+ strength_map = self._adaptive_strength_map
792
+ if strength_map.max() > 0:
793
+ normalized_strength = (strength_map / strength_map.max() * 255).astype(np.uint8)
794
+ else:
795
+ normalized_strength = np.zeros_like(strength_map, dtype=np.uint8)
796
+
797
+ # Apply colormap
798
+ strength_heatmap = cv2.applyColorMap(normalized_strength, cv2.COLORMAP_VIRIDIS)
799
+ strength_heatmap_rgb = cv2.cvtColor(strength_heatmap, cv2.COLOR_BGR2RGB)
800
+ debug_images["adaptive_strength_heatmap"] = Image.fromarray(strength_heatmap_rgb)
801
+
802
+ return debug_images
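
A short usage sketch for the blender above (file paths are placeholders; `simple_blend_images` expects the foreground, background, and L-mode mask to share the same dimensions):

```python
# Hypothetical example: composite a foreground onto a new background.
from PIL import Image
from image_blender import ImageBlender

foreground = Image.open("subject.png").convert("RGB")   # placeholder path
background = Image.open("scene.png").convert("RGB")     # placeholder path, same size as foreground
mask = Image.open("subject_mask.png").convert("L")      # 255 = foreground, 0 = background

blender = ImageBlender(enable_multi_scale=True)
result = blender.simple_blend_images(foreground, background, mask)
result.save("blended.png")

# Optional: inspect the intermediate masks/heatmaps produced during blending.
debug = blender.create_debug_images(foreground, background, mask, result)
```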
mask_generator.py ADDED
@@ -0,0 +1,650 @@
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image, ImageFilter, ImageDraw
4
+ import logging
5
+ from typing import Optional, Tuple
6
+ from scipy.ndimage import binary_erosion, binary_dilation
7
+ import io
8
+ import gc
9
+ import torch
10
+ from transformers import AutoModelForImageSegmentation
11
+ from torchvision import transforms
12
+ from rembg import remove, new_session
13
+
14
+ logger = logging.getLogger(__name__)
15
+ logger.setLevel(logging.INFO)
16
+
17
+ class MaskGenerator:
18
+ """
19
+ Intelligent mask generation using deep learning models with traditional fallback.
20
+ Priority: BiRefNet > U²-Net (rembg) > Traditional gradient-based methods
21
+ """
22
+
23
+ def __init__(self, max_image_size: int = 1024, device: str = "auto"):
24
+ self.max_image_size = max_image_size
25
+ self.device = self._setup_device(device)
26
+
27
+ # BiRefNet model (lazy loading)
28
+ self._birefnet_model = None
29
+ self._birefnet_transform = None
30
+
31
+ # Log initialization
32
+ logger.info(f"🎭 MaskGenerator initialized on {self.device}")
33
+
34
+ def _setup_device(self, device: str) -> str:
35
+ """Setup computation device"""
36
+ if device == "auto":
37
+ if torch.cuda.is_available():
38
+ return "cuda"
39
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
40
+ return "mps"
41
+ return "cpu"
42
+ return device
43
+
44
+ def _load_birefnet_model(self) -> bool:
45
+ """
46
+ Lazy load BiRefNet model for memory efficiency.
47
+ Returns True if model loaded successfully, False otherwise.
48
+ """
49
+ if self._birefnet_model is not None:
50
+ return True
51
+
52
+ try:
53
+ logger.info("📥 Loading BiRefNet model (ZhengPeng7/BiRefNet)...")
54
+
55
+ # Load model with fp16 for memory efficiency on GPU
56
+ dtype = torch.float16 if self.device == "cuda" else torch.float32
57
+
58
+ self._birefnet_model = AutoModelForImageSegmentation.from_pretrained(
59
+ "ZhengPeng7/BiRefNet",
60
+ trust_remote_code=True,
61
+ torch_dtype=dtype
62
+ )
63
+ self._birefnet_model.to(self.device)
64
+ self._birefnet_model.eval()
65
+
66
+ # Define preprocessing transform
67
+ self._birefnet_transform = transforms.Compose([
68
+ transforms.Resize((1024, 1024)),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
71
+ ])
72
+
73
+ logger.info("✅ BiRefNet model loaded successfully")
74
+ return True
75
+
76
+ except Exception as e:
77
+ logger.error(f"❌ Failed to load BiRefNet: {e}")
78
+ self._birefnet_model = None
79
+ self._birefnet_transform = None
80
+ return False
81
+
82
+ def _unload_birefnet_model(self):
83
+ """Unload BiRefNet model to free memory"""
84
+ if self._birefnet_model is not None:
85
+ del self._birefnet_model
86
+ self._birefnet_model = None
87
+ self._birefnet_transform = None
88
+
89
+ if torch.cuda.is_available():
90
+ torch.cuda.empty_cache()
91
+ gc.collect()
92
+ logger.info("🧹 BiRefNet model unloaded")
93
+
94
+ def apply_guided_filter(
95
+ self,
96
+ mask: np.ndarray,
97
+ guide_image: Image.Image,
98
+ radius: int = 8,
99
+ eps: float = 0.01
100
+ ) -> np.ndarray:
101
+ """
102
+ Apply guided filter to mask for edge-preserving smoothing.
103
+ Returns the original mask unchanged if the guided filter is unavailable or fails.
104
+
105
+ Args:
106
+ mask: Input mask as numpy array (0-255)
107
+ guide_image: Original image to use as guide
108
+ radius: Filter radius (larger = more smoothing)
109
+ eps: Regularization parameter (smaller = more edge-preserving)
110
+
111
+ Returns:
112
+ Filtered mask as numpy array (0-255)
113
+ """
114
+ try:
115
+ # Convert guide image to grayscale
116
+ guide_gray = np.array(guide_image.convert('L')).astype(np.float32) / 255.0
117
+ mask_float = mask.astype(np.float32) / 255.0
118
+
119
+ logger.info(f"🔧 Applying guided filter (radius={radius}, eps={eps})")
120
+
121
+ # Apply guided filter
122
+ filtered = cv2.ximgproc.guidedFilter(
123
+ guide=guide_gray,
124
+ src=mask_float,
125
+ radius=radius,
126
+ eps=eps
127
+ )
128
+
129
+ # Convert back to 0-255 range
130
+ result = (np.clip(filtered, 0, 1) * 255).astype(np.uint8)
131
+ logger.info("✅ Guided filter applied successfully")
132
+ return result
133
+
134
+ except Exception as e:
135
+ logger.error(f"❌ Guided filter failed: {e}, using original mask")
136
+ return mask
137
+
138
+ def try_birefnet_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
139
+ """
140
+ Generate foreground mask using BiRefNet model.
141
+ BiRefNet provides high-quality segmentation with clean edges.
142
+
143
+ Args:
144
+ original_image: Input PIL Image
145
+
146
+ Returns:
147
+ PIL Image (L mode) mask or None if failed
148
+ """
149
+ try:
150
+ # Lazy load model
151
+ if not self._load_birefnet_model():
152
+ return None
153
+
154
+ logger.info("🤖 Starting BiRefNet foreground extraction...")
155
+ original_size = original_image.size
156
+
157
+ # Convert to RGB if needed
158
+ if original_image.mode != 'RGB':
159
+ image_rgb = original_image.convert('RGB')
160
+ else:
161
+ image_rgb = original_image
162
+
163
+ # Preprocess image
164
+ input_tensor = self._birefnet_transform(image_rgb).unsqueeze(0)
165
+
166
+ # Move to device with appropriate dtype
167
+ if self.device == "cuda":
168
+ input_tensor = input_tensor.to(self.device, dtype=torch.float16)
169
+ else:
170
+ input_tensor = input_tensor.to(self.device)
171
+
172
+ # Run inference
173
+ with torch.no_grad():
174
+ outputs = self._birefnet_model(input_tensor)
175
+
176
+ # BiRefNet outputs a list, get the final prediction
177
+ if isinstance(outputs, (list, tuple)):
178
+ pred = outputs[-1]
179
+ else:
180
+ pred = outputs
181
+
182
+ # Sigmoid to get probability map
183
+ pred = torch.sigmoid(pred)
184
+
185
+ # Convert to numpy
186
+ pred_np = pred.squeeze().cpu().numpy()
187
+
188
+ # Convert to 0-255 range
189
+ mask_array = (pred_np * 255).astype(np.uint8)
190
+
191
+ # Resize back to original size
192
+ mask_pil = Image.fromarray(mask_array, mode='L')
193
+ mask_pil = mask_pil.resize(original_size, Image.LANCZOS)
194
+ mask_array = np.array(mask_pil)
195
+
196
+ # Quality check
197
+ mean_val = mask_array.mean()
198
+ nonzero_ratio = np.count_nonzero(mask_array > 50) / mask_array.size
199
+
200
+ logger.info(f"📊 BiRefNet mask stats - Mean: {mean_val:.1f}, Coverage: {nonzero_ratio:.1%}")
201
+
202
+ if mean_val < 10:
203
+ logger.warning("⚠️ BiRefNet mask too weak, falling back")
204
+ return None
205
+
206
+ if nonzero_ratio < 0.03:
207
+ logger.warning("⚠️ BiRefNet foreground coverage too low, falling back")
208
+ return None
209
+
210
+ # Light post-processing for edge refinement
211
+ # Use morphological operations to clean up
212
+ kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
213
+ mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_CLOSE, kernel_small)
214
+
215
+ logger.info("✅ BiRefNet mask generation successful!")
216
+ return Image.fromarray(mask_array, mode='L')
217
+
218
+ except torch.cuda.OutOfMemoryError:
219
+ logger.error("❌ BiRefNet: GPU memory exhausted")
220
+ self._unload_birefnet_model()
221
+ return None
222
+
223
+ except Exception as e:
224
+ logger.error(f"❌ BiRefNet mask generation failed: {e}")
225
+ import traceback
226
+ logger.error(f"📍 Traceback: {traceback.format_exc()}")
227
+ return None
228
+
229
+ def try_deep_learning_mask(self, original_image: Image.Image) -> Optional[Image.Image]:
230
+ """
231
+ Intelligent foreground extraction with model priority:
232
+ 1. BiRefNet (best quality, clean edges)
233
+ 2. U²-Net via rembg (good fallback)
234
+ 3. Return None to trigger traditional methods
235
+
236
+ Args:
237
+ original_image: Input PIL Image
238
+
239
+ Returns:
240
+ PIL Image (L mode) mask or None if all methods failed
241
+ """
242
+ # Priority 1: Try BiRefNet first
243
+ logger.info("🤖 Attempting BiRefNet mask generation...")
244
+ birefnet_mask = self.try_birefnet_mask(original_image)
245
+ if birefnet_mask is not None:
246
+ logger.info("✅ Using BiRefNet generated mask")
247
+ return birefnet_mask
248
+
249
+ # Priority 2: Fallback to rembg (U²-Net)
250
+ logger.info("🔄 BiRefNet unavailable/failed, trying rembg...")
251
+ try:
252
+ logger.info("🤖 Starting rembg foreground extraction")
253
+
254
+ # Try u2net first (better for cartoons/objects like Snoopy)
255
+ try:
256
+ session = new_session('u2net')
257
+ logger.info("✅ Using u2net model")
258
+ except Exception as e:
259
+ logger.warning(f"u2net failed ({e}), trying u2net_human_seg")
260
+ try:
261
+ session = new_session('u2net_human_seg')
262
+ logger.info("✅ Using u2net_human_seg model")
263
+ except Exception as e2:
264
+ logger.error(f"All rembg models failed: {e2}")
265
+ return None
266
+
267
+ # Convert image to bytes for rembg
268
+ img_byte_arr = io.BytesIO()
269
+ original_image.save(img_byte_arr, format='PNG')
270
+ img_byte_arr = img_byte_arr.getvalue()
271
+ logger.info(f"📷 Image size: {len(img_byte_arr)} bytes")
272
+
273
+ # Perform background removal
274
+ result = remove(img_byte_arr, session=session)
275
+ result_img = Image.open(io.BytesIO(result)).convert('RGBA')
276
+ alpha_channel = result_img.split()[-1]
277
+ alpha_array = np.array(alpha_channel)
278
+
279
+ logger.info(f"📊 Raw alpha stats - Mean: {alpha_array.mean():.1f}, Min: {alpha_array.min()}, Max: {alpha_array.max()}")
280
+
281
+ # Step 1: Light smoothing to reduce noise but preserve edges
282
+ alpha_smoothed = cv2.GaussianBlur(alpha_array, (3, 3), 0.8)
283
+
284
+ # Step 2: Contrast stretching to utilize full range
285
+ alpha_stretched = cv2.normalize(alpha_smoothed, None, 0, 255, cv2.NORM_MINMAX)
286
+
287
+ # Step 3: CRITICAL FIX - More aggressive foreground preservation
288
+ # Instead of hard threshold, use adaptive approach
289
+
290
+ # Find the main subject area (high confidence regions)
291
+ high_confidence = alpha_stretched > 180
292
+ medium_confidence = (alpha_stretched > 60) & (alpha_stretched <= 180)
293
+ low_confidence = (alpha_stretched > 15) & (alpha_stretched <= 60)
294
+
295
+ # Create final mask with better extremity handling
296
+ final_alpha = np.zeros_like(alpha_stretched)
297
+
298
+ # High confidence areas - keep at full opacity
299
+ final_alpha[high_confidence] = 255
300
+
301
+ # Medium confidence - boost significantly
302
+ final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)
303
+
304
+ # Low confidence - moderate boost (catches faint extremities)
305
+ final_alpha[low_confidence] = np.clip(alpha_stretched[low_confidence] * 2.5, 120, 199)
306
+
307
+ # Morphological operations to connect disconnected parts (hands, feet, tail)
308
+ kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
309
+ kernel_medium = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
310
+
311
+ # Close small gaps (helps connect separated body parts)
312
+ final_alpha = cv2.morphologyEx(final_alpha, cv2.MORPH_CLOSE, kernel_small, iterations=1)
313
+
314
+ # Light dilation to ensure nothing gets cut off
315
+ final_alpha = cv2.dilate(final_alpha, kernel_small, iterations=1)
316
+
317
+ logger.info(f"📊 Final alpha stats - Mean: {final_alpha.mean():.1f}, Min: {final_alpha.min()}, Max: {final_alpha.max()}")
318
+
319
+ # Quality check - but be more lenient for cartoon characters
320
+ if final_alpha.mean() < 10:
321
+ logger.warning("⚠️ Alpha still too weak, falling back to traditional method")
322
+ return None
323
+
324
+ # Enhanced post-processing for cartoon characters
325
+ is_cartoon = self._detect_cartoon_character(original_image, final_alpha)
326
+
327
+ if is_cartoon:
328
+ logger.info("🎭 Detected cartoon/character image, applying specialized processing")
329
+ final_alpha = self._enhance_cartoon_mask(original_image, final_alpha)
330
+
331
+ # Count non-zero pixels to ensure we have substantial foreground
332
+ foreground_pixels = np.count_nonzero(final_alpha > 50)
333
+ total_pixels = final_alpha.size
334
+ foreground_ratio = foreground_pixels / total_pixels
335
+ logger.info(f"📊 Foreground coverage: {foreground_ratio:.1%} of image")
336
+
337
+ if foreground_ratio < 0.05: # Less than 5% is probably too little
338
+ logger.warning("⚠️ Very low foreground coverage, falling back to traditional method")
339
+ return None
340
+
341
+ mask = Image.fromarray(final_alpha.astype(np.uint8), mode='L')
342
+ logger.info("✅ Enhanced rembg mask generation successful!")
343
+ return mask
344
+
345
+ except Exception as e:
346
+ logger.error(f"❌ Deep learning mask extraction failed: {e}")
347
+ return None
348
+
349
+ def _detect_cartoon_character(self, original_image: Image.Image, alpha_mask: np.ndarray) -> bool:
350
+ """
351
+ Detect if image is cartoon/line art (heuristic approach)
352
+ """
353
+ try:
354
+ img_array = np.array(original_image.convert('RGB'))
355
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
356
+
357
+ # Calculate edge density (cartoons usually have more clear edges)
358
+ edges = cv2.Canny(gray, 50, 150)
359
+ edge_density = np.count_nonzero(edges) / max(edges.size, 1) # Avoid division by zero
360
+
361
+ # Calculate color complexity (cartoons usually have fewer colors) - optimize memory usage
362
+ h, w, c = img_array.shape
363
+ if h * w > 100000: # If image is too large, resize for processing
364
+ small_img = cv2.resize(img_array, (200, 200))
365
+ else:
366
+ small_img = img_array
367
+
368
+ unique_colors = len(np.unique(small_img.reshape(-1, 3), axis=0))
369
+ total_pixels = small_img.shape[0] * small_img.shape[1]
370
+ color_simplicity = unique_colors < (total_pixels * 0.1)
371
+
372
+ # Check for obvious black outlines
373
+ dark_pixels_ratio = np.count_nonzero(gray < 50) / max(gray.size, 1) # Avoid division by zero
374
+ has_black_outline = dark_pixels_ratio > 0.05
375
+
376
+ # Comprehensive judgment: high edge density + color simplicity + black outline = likely cartoon
377
+ is_cartoon = (edge_density > 0.05) and (color_simplicity or has_black_outline)
378
+
379
+ logger.info(f"🔍 Cartoon detection - Edge density: {edge_density:.3f}, Color simplicity: {color_simplicity}, Black outline: {has_black_outline} -> Cartoon: {is_cartoon}")
380
+ return is_cartoon
381
+
382
+ except Exception as e:
383
+ import traceback
384
+ logger.error(f"❌ Cartoon detection failed: {e}")
385
+ logger.error(f"📍 Traceback: {traceback.format_exc()}")
386
+ print(f"❌ CARTOON DETECTION ERROR: {e}")
387
+ print(f"Traceback: {traceback.format_exc()}")
388
+ return False
389
+
390
+ def _enhance_cartoon_mask(self, original_image: Image.Image, alpha_mask: np.ndarray) -> np.ndarray:
391
+ """
392
+ Enhanced mask processing for cartoon characters
393
+ """
394
+ try:
395
+ img_array = np.array(original_image.convert('RGB'))
396
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
397
+ enhanced_alpha = alpha_mask.copy()
398
+
399
+ # Step 1: Black outline enhancement - find black outlines and enhance their alpha
400
+ th_dark = 80 # Adjustable parameter: black threshold
401
+ black_outline = gray < th_dark
402
+
403
+ # Dilate black outline region by 1px
404
+ kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) # Adjustable parameter: dilation kernel size
405
+ black_outline_dilated = cv2.dilate(black_outline.astype(np.uint8), kernel_dilate, iterations=1)
406
+
407
+ # Set black outline region alpha directly to 255
408
+ enhanced_alpha[black_outline_dilated > 0] = 255
409
+ logger.info(f"🖤 Black outline enhanced: {np.count_nonzero(black_outline_dilated)} pixels")
410
+
411
+ # Step 2: Simplified internal enhancement - process white fill areas within outlines
412
+ # Find high confidence regions (alpha ≥ 160)
413
+ high_confidence = enhanced_alpha >= 160
414
+
415
+ # Apply close operation on high confidence regions to connect separated parts
416
+ kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) # Adjustable parameter: close kernel size
417
+ high_confidence_closed = cv2.morphologyEx(high_confidence.astype(np.uint8), cv2.MORPH_CLOSE, kernel_close, iterations=1)
418
+
419
+ # Simplified approach: directly enhance medium confidence regions without complex flood fill
420
+ # Find medium/low confidence regions surrounded by high confidence regions
421
+ medium_confidence = (enhanced_alpha >= 80) & (enhanced_alpha < 160)
422
+
423
+ # Dilate high confidence region to include more internal areas
424
+ kernel_dilate_internal = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
425
+ high_confidence_expanded = cv2.dilate(high_confidence_closed, kernel_dilate_internal, iterations=1)
426
+
427
+ # Medium confidence pixels within expanded high confidence areas are considered internal fill
428
+ internal_fill_regions = medium_confidence & (high_confidence_expanded > 0)
429
+
430
+ # Enhance alpha of these internal fill regions to at least 220
431
+ min_alpha_for_fill = 220 # Adjustable parameter: minimum alpha for internal fill
432
+ enhanced_alpha[internal_fill_regions] = np.maximum(enhanced_alpha[internal_fill_regions], min_alpha_for_fill)
433
+
434
+ logger.info(f"🤍 Internal fill regions enhanced: {np.count_nonzero(internal_fill_regions)} pixels")
435
+ logger.info(f"📊 Enhanced alpha stats - Mean: {enhanced_alpha.mean():.1f}, Min: {enhanced_alpha.min()}, Max: {enhanced_alpha.max()}")
436
+
437
+ return enhanced_alpha
438
+
439
+ except Exception as e:
440
+ import traceback
441
+ logger.error(f"❌ Cartoon mask enhancement failed: {e}")
442
+ logger.error(f"📍 Traceback: {traceback.format_exc()}")
443
+ print(f"❌ CARTOON MASK ENHANCEMENT ERROR: {e}")
444
+ print(f"Traceback: {traceback.format_exc()}")
445
+ return alpha_mask
446
+
447
+ def _adjust_mask_for_scene_focus(self, mask: Image.Image, original_image: Image.Image) -> Image.Image:
448
+ """
449
+ Adjust mask for scene focus mode to include nearby objects like chairs, furniture
450
+ """
451
+ try:
452
+ logger.info("🏠 Adjusting mask for scene focus mode...")
453
+
454
+ mask_array = np.array(mask)
455
+ img_array = np.array(original_image.convert('RGB'))
456
+
457
+ # Expand mask to include nearby objects
458
+ # Use larger dilation kernel to include furniture/objects
459
+ kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
460
+ expanded_mask = cv2.dilate(mask_array, kernel_large, iterations=2)
461
+
462
+ # Find contours in the expanded area to detect objects
463
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
464
+ edges = cv2.Canny(gray, 30, 100)
465
+
466
+ # Apply edge detection only in the expanded region
467
+ expanded_region = (expanded_mask > 0) & (mask_array == 0)
468
+ object_edges = np.zeros_like(edges)
469
+ object_edges[expanded_region] = edges[expanded_region]
470
+
471
+ # Close gaps to form complete objects
472
+ kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
473
+ object_mask = cv2.morphologyEx(object_edges, cv2.MORPH_CLOSE, kernel_close)
474
+ object_mask = cv2.dilate(object_mask, kernel_close, iterations=1)
475
+
476
+ # Combine with original mask
477
+ final_mask = np.maximum(mask_array, object_mask)
478
+
479
+ logger.info("✅ Scene focus adjustment completed")
480
+ return Image.fromarray(final_mask)
481
+
482
+ except Exception as e:
483
+ logger.error(f"❌ Scene focus adjustment failed: {e}")
484
+ return mask
485
+
486
+ def create_gradient_based_mask(self, original_image: Image.Image, mode: str = "center", focus_mode: str = "person") -> Image.Image:
487
+ """
488
+ Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods
489
+ Focus mode: 'person' for tight crop around person, 'scene' for including nearby objects
490
+ """
491
+ width, height = original_image.size
492
+ logger.info(f"🎯 Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}")
493
+
494
+ if mode == "center":
495
+ # Try using deep learning models for intelligent foreground extraction
496
+ logger.info("🤖 Attempting deep learning mask generation...")
497
+ dl_mask = self.try_deep_learning_mask(original_image)
498
+ if dl_mask is not None:
499
+ logger.info("✅ Using deep learning generated mask")
500
+ # Apply focus mode adjustments to deep learning mask
501
+ if focus_mode == "scene":
502
+ dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
503
+ return dl_mask
504
+
505
+ # Fallback to traditional method
506
+ logger.info("🔄 Deep learning failed, using traditional gradient-based method")
507
+ img_array = np.array(original_image.convert('RGB'))
508
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
509
+
510
+ # First-order derivatives: use Sobel operator for edge detection
511
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
512
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
513
+ gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)
514
+
515
+ # Second-order derivatives: use Laplacian operator for texture change detection
516
+ laplacian = cv2.Laplacian(gray, cv2.CV_64F, ksize=3)
517
+ laplacian_abs = np.abs(laplacian)
518
+
519
+ # Combine first and second order derivatives
520
+ combined_edges = gradient_magnitude * 0.7 + laplacian_abs * 0.3
521
+ combined_edges = (combined_edges / np.max(combined_edges)) * 255
522
+
523
+ # Threshold processing to find strong edges
524
+ _, edge_binary = cv2.threshold(combined_edges.astype(np.uint8), 20, 255, cv2.THRESH_BINARY)
525
+
526
+ # Morphological operations to connect edges
527
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
528
+ edge_binary = cv2.morphologyEx(edge_binary, cv2.MORPH_CLOSE, kernel)
529
+
530
+ # Find contours and create mask
531
+ contours, _ = cv2.findContours(edge_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
532
+
533
+ if contours:
534
+ # Find largest contour (main subject)
535
+ largest_contour = max(contours, key=cv2.contourArea)
536
+ contour_mask = np.zeros((height, width), dtype=np.uint8)
537
+ cv2.fillPoly(contour_mask, [largest_contour], 255)
538
+
539
+ # Create foreground enhancement mask: specially protect dark regions
540
+ dark_mask = (gray < 90).astype(np.uint8) * 255
541
+ morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
542
+ dark_mask = cv2.morphologyEx(dark_mask, cv2.MORPH_CLOSE, morph_kernel, iterations=1)
543
+ dark_mask = cv2.dilate(dark_mask, morph_kernel, iterations=2)
544
+ contour_mask = cv2.bitwise_or(contour_mask, dark_mask)
545
+
546
+ # Get core foreground: clean holes and fill gaps
547
+ close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
548
+ core_mask = cv2.morphologyEx(contour_mask, cv2.MORPH_CLOSE, close_kernel, iterations=1)
549
+
550
+ open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
551
+ core_mask = cv2.morphologyEx(core_mask, cv2.MORPH_OPEN, open_kernel, iterations=1)
552
+
553
+ # Convert to binary core (0/255)
554
+ _, core_binary = cv2.threshold(core_mask, 127, 255, cv2.THRESH_BINARY)
555
+
556
+ # Keep only slight dilation to avoid foreground being eaten
557
+ dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
558
+ core_binary = cv2.dilate(core_binary, dilate_kernel, iterations=1)
559
+
560
+ # Distance transform feathering: shrink feathering range for sharp edges
561
+ FEATHER_PX = 4
562
+
563
+ # Calculate distance transform
564
+ core_float = core_binary.astype(np.float32) / 255.0
565
+ distances = cv2.distanceTransform((1 - core_float).astype(np.uint8), cv2.DIST_L2, 5)
566
+
567
+ # Create feathering mask: 0→FEATHER_PX linear mapping to 1→0
568
+ feather_mask = np.ones_like(distances)
569
+ edge_region = (distances > 0) & (distances <= FEATHER_PX)
570
+ feather_mask[edge_region] = 1.0 - (distances[edge_region] / FEATHER_PX)
571
+ feather_mask[distances > FEATHER_PX] = 0.0
572
+
573
+ # Apply double-smoothstep curve: make transition steeper, reduce semi-transparent halos
574
+ def double_smoothstep(t):
575
+ t = np.clip(t, 0, 1)
576
+ s1 = t * t * (3 - 2 * t)
577
+ return s1 * s1 * (3 - 2 * s1)  # smoothstep applied twice for a steeper, smoother falloff
578
+
579
+ # Combine core with feathering: core area keeps 255, edges use double_smoothstep feathering
580
+ final_alpha = np.zeros_like(distances)
581
+ final_alpha[core_binary > 127] = 1.0 # Core area
582
+ final_alpha[edge_region] = double_smoothstep(feather_mask[edge_region]) # Feathering area
583
+
584
+ # Convert to 0-255 range
585
+ final_mask = (final_alpha * 255).astype(np.uint8)
586
+
587
+ # Apply guided filter for edge-preserving smoothing
588
+ final_mask = self.apply_guided_filter(final_mask, original_image, radius=8, eps=0.01)
589
+
590
+ mask = Image.fromarray(final_mask)
591
+ else:
592
+ # Backup plan: use large ellipse
593
+ mask = Image.new('L', (width, height), 0)
594
+ draw = ImageDraw.Draw(mask)
595
+ center_x, center_y = width // 2, height // 2
596
+ width_radius = int(width * 0.45)
597
+ height_radius = int(height * 0.48)
598
+ draw.ellipse([
599
+ center_x - width_radius, center_y - height_radius,
600
+ center_x + width_radius, center_y + height_radius
601
+ ], fill=255)
602
+ # Apply guided filter instead of Gaussian blur
603
+ mask_array = np.array(mask)
604
+ mask_array = self.apply_guided_filter(mask_array, original_image, radius=10, eps=0.02)
605
+ mask = Image.fromarray(mask_array)
606
+
607
+ elif mode == "left_half":
608
+ # Keep original logic unchanged - ensure Snoopy and other functions work normally
609
+ mask = Image.new('L', (width, height), 0)
610
+ mask_array = np.array(mask)
611
+ mask_array[:, :width//2] = 255
612
+
613
+ transition_zone = width // 10
614
+ for i in range(transition_zone):
615
+ x_pos = width//2 + i
616
+ if x_pos < width:
617
+ alpha = 255 * (1 - i / transition_zone)
618
+ mask_array[:, x_pos] = int(alpha)
619
+
620
+ mask = Image.fromarray(mask_array)
621
+
622
+ elif mode == "right_half":
623
+ # Keep original logic unchanged - ensure Snoopy and other functions work normally
624
+ mask = Image.new('L', (width, height), 0)
625
+ mask_array = np.array(mask)
626
+ mask_array[:, width//2:] = 255
627
+
628
+ transition_zone = width // 10
629
+ for i in range(transition_zone):
630
+ x_pos = width//2 - i - 1
631
+ if x_pos >= 0:
632
+ alpha = 255 * (1 - i / transition_zone)
633
+ mask_array[:, x_pos] = int(alpha)
634
+
635
+ mask = Image.fromarray(mask_array)
636
+
637
+ elif mode == "full":
638
+ mask = Image.new('L', (width, height), 0)
639
+ draw = ImageDraw.Draw(mask)
640
+ center_x, center_y = width // 2, height // 2
641
+ radius = min(width, height) // 8
642
+
643
+ draw.ellipse([
644
+ center_x - radius, center_y - radius,
645
+ center_x + radius, center_y + radius
646
+ ], fill=255)
647
+
648
+ mask = mask.filter(ImageFilter.GaussianBlur(radius=5))
649
+
650
+ return mask
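For reference, a minimal usage sketch of the MaskGenerator defined above; the input filename is a hypothetical placeholder, and the call follows the default center/person path that tries BiRefNet first, then rembg, then the gradient-based fallback.

from PIL import Image
from mask_generator import MaskGenerator

generator = MaskGenerator(max_image_size=1024, device="auto")
photo = Image.open("portrait.jpg")  # hypothetical input image

# "center" mode attempts the deep-learning extractors before the traditional fallback
mask = generator.create_gradient_based_mask(photo, mode="center", focus_mode="person")
mask.save("portrait_mask.png")  # grayscale L-mode mask where 255 marks the foreground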
model_manager.py ADDED
@@ -0,0 +1,293 @@
1
+ import logging
2
+ import gc
3
+ import time
4
+ from typing import Dict, Any, Optional, Callable
5
+ from dataclasses import dataclass, field
6
+ from threading import Lock
7
+ import torch
8
+
9
+ logger = logging.getLogger(__name__)
10
+ logger.setLevel(logging.INFO)
11
+
12
+
13
+ @dataclass
14
+ class ModelInfo:
15
+ """Information about a registered model."""
16
+ name: str
17
+ loader: Callable[[], Any]
18
+ is_critical: bool = False # Critical models are not unloaded under memory pressure
19
+ estimated_memory_gb: float = 0.0
20
+ is_loaded: bool = False
21
+ last_used: float = 0.0
22
+ model_instance: Any = None
23
+
24
+
25
+ class ModelManager:
26
+ """
27
+ Singleton model manager for unified model lifecycle management.
28
+ Handles lazy loading, caching, and intelligent memory management.
29
+ """
30
+
31
+ _instance = None
32
+ _lock = Lock()
33
+
34
+ def __new__(cls):
35
+ if cls._instance is None:
36
+ with cls._lock:
37
+ if cls._instance is None:
38
+ cls._instance = super().__new__(cls)
39
+ cls._instance._initialized = False
40
+ return cls._instance
41
+
42
+ def __init__(self):
43
+ if self._initialized:
44
+ return
45
+
46
+ self._models: Dict[str, ModelInfo] = {}
47
+ self._memory_threshold = 0.80 # Trigger cleanup at 80% GPU memory usage
48
+ self._device = self._detect_device()
49
+
50
+ logger.info(f"🔧 ModelManager initialized on {self._device}")
51
+ self._initialized = True
52
+
53
+ def _detect_device(self) -> str:
54
+ """Detect best available device."""
55
+ if torch.cuda.is_available():
56
+ return "cuda"
57
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
58
+ return "mps"
59
+ return "cpu"
60
+
61
+ def register_model(
62
+ self,
63
+ name: str,
64
+ loader: Callable[[], Any],
65
+ is_critical: bool = False,
66
+ estimated_memory_gb: float = 0.0
67
+ ):
68
+ """
69
+ Register a model for managed loading.
70
+
71
+ Args:
72
+ name: Unique model identifier
73
+ loader: Callable that returns the loaded model
74
+ is_critical: If True, model won't be unloaded under memory pressure
75
+ estimated_memory_gb: Estimated GPU memory usage in GB
76
+ """
77
+ if name in self._models:
78
+ logger.warning(f"⚠️ Model '{name}' already registered, updating")
79
+
80
+ self._models[name] = ModelInfo(
81
+ name=name,
82
+ loader=loader,
83
+ is_critical=is_critical,
84
+ estimated_memory_gb=estimated_memory_gb,
85
+ is_loaded=False,
86
+ last_used=0.0,
87
+ model_instance=None
88
+ )
89
+ logger.info(f"📝 Registered model: {name} (critical={is_critical}, ~{estimated_memory_gb:.1f}GB)")
90
+
91
+ def load_model(self, name: str) -> Any:
92
+ """
93
+ Load a model by name. Returns cached instance if already loaded.
94
+
95
+ Args:
96
+ name: Model identifier
97
+
98
+ Returns:
99
+ Loaded model instance
100
+
101
+ Raises:
102
+ KeyError: If model not registered
103
+ RuntimeError: If loading fails
104
+ """
105
+ if name not in self._models:
106
+ raise KeyError(f"Model '{name}' not registered")
107
+
108
+ model_info = self._models[name]
109
+
110
+ # Return cached instance
111
+ if model_info.is_loaded and model_info.model_instance is not None:
112
+ model_info.last_used = time.time()
113
+ logger.debug(f"📦 Using cached model: {name}")
114
+ return model_info.model_instance
115
+
116
+ # Check memory pressure before loading
117
+ self.check_memory_pressure()
118
+
119
+ # Load the model
120
+ try:
121
+ logger.info(f"📥 Loading model: {name}")
122
+ start_time = time.time()
123
+
124
+ model_instance = model_info.loader()
125
+
126
+ model_info.model_instance = model_instance
127
+ model_info.is_loaded = True
128
+ model_info.last_used = time.time()
129
+
130
+ load_time = time.time() - start_time
131
+ logger.info(f"✅ Model '{name}' loaded in {load_time:.1f}s")
132
+
133
+ return model_instance
134
+
135
+ except Exception as e:
136
+ logger.error(f"❌ Failed to load model '{name}': {e}")
137
+ raise RuntimeError(f"Model loading failed: {e}")
138
+
139
+ def unload_model(self, name: str):
140
+ """
141
+ Unload a specific model to free memory.
142
+
143
+ Args:
144
+ name: Model identifier
145
+ """
146
+ if name not in self._models:
147
+ return
148
+
149
+ model_info = self._models[name]
150
+
151
+ if not model_info.is_loaded:
152
+ return
153
+
154
+ try:
155
+ logger.info(f"🗑️ Unloading model: {name}")
156
+
157
+ # Delete model instance
158
+ if model_info.model_instance is not None:
159
+ del model_info.model_instance
160
+
161
+ model_info.model_instance = None
162
+ model_info.is_loaded = False
163
+
164
+ # Cleanup
165
+ gc.collect()
166
+ if torch.cuda.is_available():
167
+ torch.cuda.empty_cache()
168
+
169
+ logger.info(f"✅ Model '{name}' unloaded")
170
+
171
+ except Exception as e:
172
+ logger.error(f"❌ Error unloading model '{name}': {e}")
173
+
174
+ def check_memory_pressure(self) -> bool:
175
+ """
176
+ Check GPU memory usage and unload least-used non-critical models if needed.
177
+
178
+ Returns:
179
+ True if cleanup was performed
180
+ """
181
+ if not torch.cuda.is_available():
182
+ return False
183
+
184
+ allocated = torch.cuda.memory_allocated() / 1024**3
185
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
186
+ usage_ratio = allocated / total
187
+
188
+ if usage_ratio < self._memory_threshold:
189
+ return False
190
+
191
+ logger.warning(f"⚠️ Memory pressure detected: {usage_ratio:.1%} used")
192
+
193
+ # Find non-critical models sorted by last used time
194
+ unloadable = [
195
+ (name, info) for name, info in self._models.items()
196
+ if info.is_loaded and not info.is_critical
197
+ ]
198
+ unloadable.sort(key=lambda x: x[1].last_used)
199
+
200
+ # Unload oldest non-critical models
201
+ cleaned = False
202
+ for name, info in unloadable:
203
+ self.unload_model(name)
204
+ cleaned = True
205
+
206
+ # Re-check memory
207
+ new_ratio = torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory
208
+ if new_ratio < self._memory_threshold * 0.7: # Target 70% of threshold
209
+ break
210
+
211
+ return cleaned
212
+
213
+ def force_cleanup(self):
214
+ """Force cleanup all non-critical models and clear caches."""
215
+ logger.info("🧹 Force cleanup initiated")
216
+
217
+ # Unload all non-critical models
218
+ for name, info in self._models.items():
219
+ if info.is_loaded and not info.is_critical:
220
+ self.unload_model(name)
221
+
222
+ # Aggressive garbage collection
223
+ for _ in range(5):
224
+ gc.collect()
225
+
226
+ if torch.cuda.is_available():
227
+ torch.cuda.empty_cache()
228
+ torch.cuda.ipc_collect()
229
+ torch.cuda.synchronize()
230
+
231
+ logger.info("✅ Force cleanup completed")
232
+
233
+ def get_memory_status(self) -> Dict[str, Any]:
234
+ """
235
+ Get detailed memory status.
236
+
237
+ Returns:
238
+ Dictionary with memory statistics
239
+ """
240
+ status = {
241
+ "device": self._device,
242
+ "models": {},
243
+ "total_estimated_gb": 0.0
244
+ }
245
+
246
+ # Model status
247
+ for name, info in self._models.items():
248
+ status["models"][name] = {
249
+ "loaded": info.is_loaded,
250
+ "critical": info.is_critical,
251
+ "estimated_gb": info.estimated_memory_gb,
252
+ "last_used": info.last_used
253
+ }
254
+ if info.is_loaded:
255
+ status["total_estimated_gb"] += info.estimated_memory_gb
256
+
257
+ # GPU memory
258
+ if torch.cuda.is_available():
259
+ allocated = torch.cuda.memory_allocated() / 1024**3
260
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
261
+ cached = torch.cuda.memory_reserved() / 1024**3
262
+
263
+ status["gpu"] = {
264
+ "allocated_gb": round(allocated, 2),
265
+ "total_gb": round(total, 2),
266
+ "cached_gb": round(cached, 2),
267
+ "free_gb": round(total - allocated, 2),
268
+ "usage_percent": round((allocated / total) * 100, 1)
269
+ }
270
+
271
+ return status
272
+
273
+ def get_loaded_models(self) -> list:
274
+ """Get list of currently loaded model names."""
275
+ return [name for name, info in self._models.items() if info.is_loaded]
276
+
277
+ def is_model_loaded(self, name: str) -> bool:
278
+ """Check if a specific model is loaded."""
279
+ if name not in self._models:
280
+ return False
281
+ return self._models[name].is_loaded
282
+
283
+
284
+ # Global singleton instance
285
+ _model_manager = None
286
+
287
+
288
+ def get_model_manager() -> ModelManager:
289
+ """Get the global ModelManager singleton instance."""
290
+ global _model_manager
291
+ if _model_manager is None:
292
+ _model_manager = ModelManager()
293
+ return _model_manager
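A short usage sketch of the manager above. The model name and loader body are illustrative assumptions; the register_model/load_model/get_memory_status calls match the API defined in this file.

from model_manager import get_model_manager

manager = get_model_manager()  # process-wide singleton

# Register a lazy loader; nothing is loaded until load_model() is called
manager.register_model(
    name="birefnet",              # illustrative identifier
    loader=lambda: object(),      # stand-in for a real model constructor
    is_critical=False,
    estimated_memory_gb=1.5,
)

model = manager.load_model("birefnet")   # loads once, then serves the cached instance
print(manager.get_memory_status())       # per-model and GPU memory statistics
manager.unload_model("birefnet")         # release it explicitly when no longer needed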
quality_checker.py ADDED
@@ -0,0 +1,409 @@
1
+ import logging
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ from typing import Dict, Any, Tuple, Optional
6
+ from dataclasses import dataclass
7
+
8
+ logger = logging.getLogger(__name__)
9
+ logger.setLevel(logging.INFO)
10
+
11
+
12
+ @dataclass
13
+ class QualityResult:
14
+ """Result of a quality check."""
15
+ score: float # 0-100
16
+ passed: bool
17
+ issue: str
18
+ details: Dict[str, Any]
19
+
20
+
21
+ class QualityChecker:
22
+ """
23
+ Automated quality validation system for generated images.
24
+ Provides checks for mask coverage, edge continuity, and color harmony.
25
+ """
26
+
27
+ # Quality thresholds
28
+ THRESHOLD_PASS = 70
29
+ THRESHOLD_WARNING = 50
30
+
31
+ def __init__(self, strictness: str = "standard"):
32
+ """
33
+ Initialize QualityChecker.
34
+
35
+ Args:
36
+ strictness: Quality check strictness level
37
+ "lenient" - Only check fatal issues
38
+ "standard" - All checks with moderate thresholds
39
+ "strict" - High standards required
40
+ """
41
+ self.strictness = strictness
42
+ self._set_thresholds()
43
+
44
+ def _set_thresholds(self):
45
+ """Set quality thresholds based on strictness level."""
46
+ if self.strictness == "lenient":
47
+ self.min_coverage = 0.03 # 3%
48
+ self.min_edge_score = 40
49
+ self.min_harmony_score = 40
50
+ elif self.strictness == "strict":
51
+ self.min_coverage = 0.10 # 10%
52
+ self.min_edge_score = 75
53
+ self.min_harmony_score = 75
54
+ else: # standard
55
+ self.min_coverage = 0.05 # 5%
56
+ self.min_edge_score = 60
57
+ self.min_harmony_score = 60
58
+
59
+ def check_mask_coverage(self, mask: Image.Image) -> QualityResult:
60
+ """
61
+ Verify mask coverage is adequate.
62
+
63
+ Args:
64
+ mask: Grayscale mask image (L mode)
65
+
66
+ Returns:
67
+ QualityResult with coverage analysis
68
+ """
69
+ try:
70
+ mask_array = np.array(mask.convert('L'))
71
+ height, width = mask_array.shape
72
+ total_pixels = height * width
73
+
74
+ # Count foreground pixels
75
+ fg_pixels = np.count_nonzero(mask_array > 127)
76
+ coverage_ratio = fg_pixels / total_pixels
77
+
78
+ # Check for isolated small regions (noise)
79
+ _, binary = cv2.threshold(mask_array, 127, 255, cv2.THRESH_BINARY)
80
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
81
+
82
+ # Count significant regions (> 1% of image)
83
+ min_region_size = total_pixels * 0.01
84
+ significant_regions = sum(1 for i in range(1, num_labels)
85
+ if stats[i, cv2.CC_STAT_AREA] > min_region_size)
86
+
87
+ # Calculate fragmentation (many small regions = bad)
88
+ fragmentation_penalty = max(0, (num_labels - 1 - significant_regions) * 2)
89
+
90
+ # Score calculation
91
+ coverage_score = min(100, coverage_ratio * 200) # 50% coverage = 100 score
92
+ final_score = max(0, coverage_score - fragmentation_penalty)
93
+
94
+ # Determine pass/fail
95
+ passed = coverage_ratio >= self.min_coverage and significant_regions >= 1
96
+ issue = ""
97
+
98
+ if coverage_ratio < self.min_coverage:
99
+ issue = f"Low foreground coverage ({coverage_ratio:.1%})"
100
+ elif significant_regions == 0:
101
+ issue = "No significant foreground regions detected"
102
+ elif fragmentation_penalty > 20:
103
+ issue = f"Fragmented mask with {num_labels - 1} isolated regions"
104
+
105
+ return QualityResult(
106
+ score=final_score,
107
+ passed=passed,
108
+ issue=issue,
109
+ details={
110
+ "coverage_ratio": coverage_ratio,
111
+ "foreground_pixels": fg_pixels,
112
+ "total_regions": num_labels - 1,
113
+ "significant_regions": significant_regions
114
+ }
115
+ )
116
+
117
+ except Exception as e:
118
+ logger.error(f"❌ Mask coverage check failed: {e}")
119
+ return QualityResult(score=0, passed=False, issue=str(e), details={})
120
+
121
+ def check_edge_continuity(self, mask: Image.Image) -> QualityResult:
122
+ """
123
+ Check if mask edges are continuous and smooth.
124
+
125
+ Args:
126
+ mask: Grayscale mask image
127
+
128
+ Returns:
129
+ QualityResult with edge analysis
130
+ """
131
+ try:
132
+ mask_array = np.array(mask.convert('L'))
133
+
134
+ # Find edges using morphological gradient
135
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
136
+ gradient = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, kernel)
137
+
138
+ # Get edge pixels
139
+ edge_pixels = gradient > 20
140
+ edge_count = np.count_nonzero(edge_pixels)
141
+
142
+ if edge_count == 0:
143
+ return QualityResult(
144
+ score=50,
145
+ passed=False,
146
+ issue="No edges detected in mask",
147
+ details={"edge_count": 0}
148
+ )
149
+
150
+ # Check edge smoothness using Laplacian
151
+ laplacian = cv2.Laplacian(mask_array, cv2.CV_64F)
152
+ edge_laplacian = np.abs(laplacian[edge_pixels])
153
+
154
+ # High Laplacian values indicate jagged edges
155
+ smoothness = 100 - min(100, np.std(edge_laplacian) * 0.5)
156
+
157
+ # Check for gaps in edges
158
+ # Dilate and erode to find disconnections
159
+ dilated = cv2.dilate(gradient, kernel, iterations=1)
160
+ eroded = cv2.erode(dilated, kernel, iterations=1)
161
+ gaps = cv2.subtract(dilated, eroded)
162
+ gap_ratio = np.count_nonzero(gaps) / max(edge_count, 1)
163
+
164
+ # Calculate final score
165
+ gap_penalty = min(40, gap_ratio * 100)
166
+ final_score = max(0, smoothness - gap_penalty)
167
+
168
+ passed = final_score >= self.min_edge_score
169
+ issue = ""
170
+
171
+ if final_score < self.min_edge_score:
172
+ if smoothness < 60:
173
+ issue = "Jagged or rough edges detected"
174
+ elif gap_ratio > 0.3:
175
+ issue = "Discontinuous edges with gaps"
176
+ else:
177
+ issue = "Poor edge quality"
178
+
179
+ return QualityResult(
180
+ score=final_score,
181
+ passed=passed,
182
+ issue=issue,
183
+ details={
184
+ "edge_count": edge_count,
185
+ "smoothness": smoothness,
186
+ "gap_ratio": gap_ratio
187
+ }
188
+ )
189
+
190
+ except Exception as e:
191
+ logger.error(f"❌ Edge continuity check failed: {e}")
192
+ return QualityResult(score=0, passed=False, issue=str(e), details={})
193
+
194
+ def check_color_harmony(
195
+ self,
196
+ foreground: Image.Image,
197
+ background: Image.Image,
198
+ mask: Image.Image
199
+ ) -> QualityResult:
200
+ """
201
+ Evaluate color harmony between foreground and background.
202
+
203
+ Args:
204
+ foreground: Original foreground image
205
+ background: Generated background image
206
+ mask: Combination mask
207
+
208
+ Returns:
209
+ QualityResult with harmony analysis
210
+ """
211
+ try:
212
+ fg_array = np.array(foreground.convert('RGB'))
213
+ bg_array = np.array(background.convert('RGB'))
214
+ mask_array = np.array(mask.convert('L'))
215
+
216
+ # Get foreground and background regions
217
+ fg_region = mask_array > 127
218
+ bg_region = mask_array <= 127
219
+
220
+ if not np.any(fg_region) or not np.any(bg_region):
221
+ return QualityResult(
222
+ score=50,
223
+ passed=True,
224
+ issue="Cannot analyze harmony - insufficient regions",
225
+ details={}
226
+ )
227
+
228
+ # Convert to LAB for perceptual analysis
229
+ fg_lab = cv2.cvtColor(fg_array, cv2.COLOR_RGB2LAB).astype(np.float32)
230
+ bg_lab = cv2.cvtColor(bg_array, cv2.COLOR_RGB2LAB).astype(np.float32)
231
+
232
+ # Calculate average colors
233
+ fg_avg_l = np.mean(fg_lab[fg_region, 0])
234
+ fg_avg_a = np.mean(fg_lab[fg_region, 1])
235
+ fg_avg_b = np.mean(fg_lab[fg_region, 2])
236
+
237
+ bg_avg_l = np.mean(bg_lab[bg_region, 0])
238
+ bg_avg_a = np.mean(bg_lab[bg_region, 1])
239
+ bg_avg_b = np.mean(bg_lab[bg_region, 2])
240
+
241
+ # Calculate color differences
242
+ delta_l = abs(fg_avg_l - bg_avg_l)
243
+ delta_a = abs(fg_avg_a - bg_avg_a)
244
+ delta_b = abs(fg_avg_b - bg_avg_b)
245
+
246
+ # Overall color difference (Delta E approximation)
247
+ delta_e = np.sqrt(delta_l**2 + delta_a**2 + delta_b**2)
248
+
249
+ # Score calculation
250
+ # Moderate difference is good (20-60 Delta E)
251
+ # Too similar or too different is problematic
252
+ if delta_e < 10:
253
+ harmony_score = 60 # Too similar, foreground may get lost
254
+ issue = "Foreground and background colors too similar"
255
+ elif delta_e > 80:
256
+ harmony_score = 50 # Too different, may look unnatural
257
+ issue = "High color contrast may look unnatural"
258
+ elif 20 <= delta_e <= 60:
259
+ harmony_score = 100 # Ideal range
260
+ issue = ""
261
+ else:
262
+ harmony_score = 80
263
+ issue = ""
264
+
265
+ # Check for extreme contrast (very dark fg on very bright bg or vice versa)
266
+ brightness_contrast = abs(fg_avg_l - bg_avg_l)
267
+ if brightness_contrast > 100:
268
+ harmony_score = max(40, harmony_score - 30)
269
+ issue = "Extreme brightness contrast between foreground and background"
270
+
271
+ passed = harmony_score >= self.min_harmony_score
272
+
273
+ return QualityResult(
274
+ score=harmony_score,
275
+ passed=passed,
276
+ issue=issue,
277
+ details={
278
+ "delta_e": delta_e,
279
+ "delta_l": delta_l,
280
+ "delta_a": delta_a,
281
+ "delta_b": delta_b,
282
+ "fg_luminance": fg_avg_l,
283
+ "bg_luminance": bg_avg_l
284
+ }
285
+ )
286
+
287
+ except Exception as e:
288
+ logger.error(f"❌ Color harmony check failed: {e}")
289
+ return QualityResult(score=0, passed=False, issue=str(e), details={})
290
+
291
+ def run_all_checks(
292
+ self,
293
+ foreground: Image.Image,
294
+ background: Image.Image,
295
+ mask: Image.Image,
296
+ combined: Optional[Image.Image] = None
297
+ ) -> Dict[str, Any]:
298
+ """
299
+ Run all quality checks and return comprehensive results.
300
+
301
+ Args:
302
+ foreground: Original foreground image
303
+ background: Generated background
304
+ mask: Combination mask
305
+ combined: Final combined image (optional)
306
+
307
+ Returns:
308
+ Dictionary with all check results and overall score
309
+ """
310
+ logger.info("🔍 Running quality checks...")
311
+
312
+ results = {
313
+ "checks": {},
314
+ "overall_score": 0,
315
+ "passed": True,
316
+ "warnings": [],
317
+ "errors": []
318
+ }
319
+
320
+ # Run individual checks
321
+ coverage_result = self.check_mask_coverage(mask)
322
+ results["checks"]["mask_coverage"] = {
323
+ "score": coverage_result.score,
324
+ "passed": coverage_result.passed,
325
+ "issue": coverage_result.issue,
326
+ "details": coverage_result.details
327
+ }
328
+
329
+ edge_result = self.check_edge_continuity(mask)
330
+ results["checks"]["edge_continuity"] = {
331
+ "score": edge_result.score,
332
+ "passed": edge_result.passed,
333
+ "issue": edge_result.issue,
334
+ "details": edge_result.details
335
+ }
336
+
337
+ harmony_result = self.check_color_harmony(foreground, background, mask)
338
+ results["checks"]["color_harmony"] = {
339
+ "score": harmony_result.score,
340
+ "passed": harmony_result.passed,
341
+ "issue": harmony_result.issue,
342
+ "details": harmony_result.details
343
+ }
344
+
345
+ # Calculate overall score (weighted average)
346
+ weights = {
347
+ "mask_coverage": 0.4,
348
+ "edge_continuity": 0.3,
349
+ "color_harmony": 0.3
350
+ }
351
+
352
+ total_score = (
353
+ coverage_result.score * weights["mask_coverage"] +
354
+ edge_result.score * weights["edge_continuity"] +
355
+ harmony_result.score * weights["color_harmony"]
356
+ )
357
+ results["overall_score"] = round(total_score, 1)
358
+
359
+ # Determine overall pass/fail
360
+ results["passed"] = all([
361
+ coverage_result.passed,
362
+ edge_result.passed,
363
+ harmony_result.passed
364
+ ])
365
+
366
+ # Collect warnings and errors
367
+ for check_name, check_data in results["checks"].items():
368
+ if check_data["issue"]:
369
+ if check_data["passed"]:
370
+ results["warnings"].append(f"{check_name}: {check_data['issue']}")
371
+ else:
372
+ results["errors"].append(f"{check_name}: {check_data['issue']}")
373
+
374
+ logger.info(f"📊 Quality check complete - Score: {results['overall_score']}, Passed: {results['passed']}")
375
+
376
+ return results
377
+
378
+ def get_quality_summary(self, results: Dict[str, Any]) -> str:
379
+ """
380
+ Generate human-readable quality summary.
381
+
382
+ Args:
383
+ results: Results from run_all_checks
384
+
385
+ Returns:
386
+ Summary string
387
+ """
388
+ score = results["overall_score"]
389
+ passed = results["passed"]
390
+
391
+ if score >= 90:
392
+ grade = "Excellent"
393
+ elif score >= 75:
394
+ grade = "Good"
395
+ elif score >= 60:
396
+ grade = "Acceptable"
397
+ elif score >= 40:
398
+ grade = "Needs Improvement"
399
+ else:
400
+ grade = "Poor"
401
+
402
+ summary = f"Quality: {grade} ({score:.0f}/100)"
403
+
404
+ if results["errors"]:
405
+ summary += f"\nIssues: {'; '.join(results['errors'])}"
406
+ elif results["warnings"]:
407
+ summary += f"\nNotes: {'; '.join(results['warnings'])}"
408
+
409
+ return summary
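A minimal usage sketch of the checker above; the three image paths are hypothetical placeholders, while the run_all_checks and get_quality_summary calls match the API defined in this file.

from PIL import Image
from quality_checker import QualityChecker

checker = QualityChecker(strictness="standard")

foreground = Image.open("foreground.png")    # original subject (hypothetical path)
background = Image.open("background.png")    # generated scene (hypothetical path)
mask = Image.open("mask.png").convert("L")   # blend mask, 255 = keep foreground

results = checker.run_all_checks(foreground, background, mask)
print(results["overall_score"], results["passed"])
print(checker.get_quality_summary(results))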
requirements.txt ADDED
@@ -0,0 +1,81 @@
1
+ # SceneWeaver Hugging Face Spaces Deployment Requirements
2
+ # Optimized for ZeroGPU environment with safe version ranges
3
+
4
+ # ============================================
5
+ # Core Deep Learning Framework
6
+ # ============================================
7
+ # PyTorch 2.x series - compatible with SDXL and xformers
8
+ torch>=2.0.0,<2.5.0
9
+ torchvision>=0.15.0,<0.20.0
10
+ torchaudio>=2.0.0,<2.5.0
11
+
12
+ # ============================================
13
+ # Diffusion Models and Transformers
14
+ # ============================================
15
+ # Diffusers 0.25+ has better SDXL support, <0.32 for stability
16
+ diffusers>=0.25.0,<0.32.0
17
+ # Transformers compatible with diffusers and open_clip
18
+ transformers>=4.35.0,<4.46.0
19
+ # Accelerate for model loading optimizations
20
+ accelerate>=0.24.0,<0.35.0
21
+ # xformers for memory efficient attention (optional, may fail on some systems)
22
+ # xformers>=0.0.22,<0.0.29
23
+
24
+ # ============================================
25
+ # Computer Vision and Image Processing
26
+ # ============================================
27
+ # OpenCV for image processing
28
+ opencv-python>=4.8.0,<4.11.0
29
+ # opencv-contrib-python for guided filter (cv2.ximgproc)
30
+ opencv-contrib-python>=4.8.0,<4.11.0
31
+ # Pillow for image I/O
32
+ Pillow>=9.5.0,<11.0.0
33
+ # SciPy for scientific computing
34
+ scipy>=1.10.0,<1.15.0
35
+
36
+ # ============================================
37
+ # Background Removal
38
+ # ============================================
39
+ # rembg for foreground extraction (CPU version for compatibility)
40
+ rembg>=2.0.50,<2.1.0
41
+
42
+ # ============================================
43
+ # Multi-modal Understanding (CLIP)
44
+ # ============================================
45
+ # OpenCLIP for image analysis
46
+ open_clip_torch>=2.20.0,<2.27.0
47
+ # Sentence transformers (dependency)
48
+ sentence-transformers>=2.2.0,<3.1.0
49
+
50
+ # ============================================
51
+ # Web Interface
52
+ # ============================================
53
+ # Gradio 4.x for modern UI
54
+ gradio>=4.0.0,<5.0.0
55
+
56
+ # ============================================
57
+ # Core Scientific Computing
58
+ # ============================================
59
+ # NumPy 1.x for compatibility
60
+ numpy>=1.24.0,<2.0.0
61
+
62
+ # ============================================
63
+ # Hugging Face Integration
64
+ # ============================================
65
+ # HuggingFace Hub for model downloads
66
+ huggingface_hub>=0.19.0,<0.27.0
67
+ # Safetensors for efficient model loading
68
+ safetensors>=0.4.0,<0.5.0
69
+
70
+ # ============================================
71
+ # System Utilities
72
+ # ============================================
73
+ # psutil for memory monitoring
74
+ psutil>=5.9.0,<6.1.0
75
+ # requests for HTTP operations
76
+ requests>=2.28.0,<2.33.0
77
+
78
+ # ============================================
79
+ # Hugging Face Spaces (auto-installed on Spaces)
80
+ # ============================================
81
+ # spaces # ZeroGPU support - auto-available on HF Spaces
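As a quick sanity check after installing these requirements, the installed versions of the core packages can be printed with the standard library; the package names below are taken from the pins above.

import importlib.metadata as md

# Spot-check a few of the pinned packages; raises PackageNotFoundError if one is missing
for pkg in ["torch", "diffusers", "transformers", "gradio", "rembg", "opencv-contrib-python"]:
    print(f"{pkg}: {md.version(pkg)}")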
scene_templates.py ADDED
@@ -0,0 +1,429 @@
1
+ import logging
2
+ from typing import Dict, List, Optional
3
+ from dataclasses import dataclass
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ @dataclass
9
+ class SceneTemplate:
10
+ """Data class representing a scene template"""
11
+ key: str
12
+ name: str
13
+ prompt: str
14
+ negative_extra: str
15
+ category: str
16
+ icon: str
17
+ guidance_scale: float = 7.5
18
+
19
+
20
+ class SceneTemplateManager:
21
+ """
22
+ Manages curated scene templates for background generation.
23
+ Provides categorized presets that users can select with one click.
24
+ """
25
+
26
+ # Scene template definitions
27
+ TEMPLATES: Dict[str, SceneTemplate] = {
28
+ # === Professional Category ===
29
+ "office_modern": SceneTemplate(
30
+ key="office_modern",
31
+ name="Modern Office",
32
+ prompt="modern minimalist office interior, clean white desk, large floor-to-ceiling windows, natural daylight, professional corporate environment, soft shadows, contemporary furniture",
33
+ negative_extra="messy, cluttered, dark, old",
34
+ category="Professional",
35
+ icon="🏢",
36
+ guidance_scale=7.5
37
+ ),
38
+ "office_executive": SceneTemplate(
39
+ key="office_executive",
40
+ name="Executive Suite",
41
+ prompt="luxurious executive office, mahogany desk, leather chair, city skyline view through windows, warm ambient lighting, bookshelf, elegant professional setting",
42
+ negative_extra="cheap, cramped, messy",
43
+ category="Professional",
44
+ icon="👔",
45
+ guidance_scale=7.5
46
+ ),
47
+ "studio_white": SceneTemplate(
48
+ key="studio_white",
49
+ name="White Studio",
50
+ prompt="clean white photography studio background, professional lighting setup, seamless white backdrop, soft diffused light, minimal shadows",
51
+ negative_extra="colored, textured, dirty",
52
+ category="Professional",
53
+ icon="📷",
54
+ guidance_scale=8.0
55
+ ),
56
+ "coworking": SceneTemplate(
57
+ key="coworking",
58
+ name="Coworking Space",
59
+ prompt="modern coworking space, open plan office, plants, exposed brick, industrial chic design, natural light, collaborative environment",
60
+ negative_extra="empty, dark, boring",
61
+ category="Professional",
62
+ icon="💼",
63
+ guidance_scale=7.0
64
+ ),
65
+ "conference": SceneTemplate(
66
+ key="conference",
67
+ name="Conference Room",
68
+ prompt="modern conference room, large meeting table, glass walls, professional presentation screen, bright corporate lighting, clean minimal design",
69
+ negative_extra="small, cramped, outdated",
70
+ category="Professional",
71
+ icon="🤝",
72
+ guidance_scale=7.5
73
+ ),
74
+
75
+ # === Nature Category ===
76
+ "beach_sunset": SceneTemplate(
77
+ key="beach_sunset",
78
+ name="Sunset Beach",
79
+ prompt="beautiful tropical beach at golden hour sunset, palm trees silhouette, calm turquoise ocean waves, warm orange and pink sky, soft sand, paradise vacation vibes",
80
+ negative_extra="storm, rain, crowded, trash",
81
+ category="Nature",
82
+ icon="🏖️",
83
+ guidance_scale=7.0
84
+ ),
85
+ "forest_enchanted": SceneTemplate(
86
+ key="forest_enchanted",
87
+ name="Enchanted Forest",
88
+ prompt="magical enchanted forest, sunlight streaming through tall trees, lush green foliage, mystical atmosphere, morning mist, fairy tale woodland",
89
+ negative_extra="dead trees, dark, scary, barren",
90
+ category="Nature",
91
+ icon="🌲",
92
+ guidance_scale=7.0
93
+ ),
94
+ "mountain_scenic": SceneTemplate(
95
+ key="mountain_scenic",
96
+ name="Mountain Vista",
97
+ prompt="breathtaking mountain landscape, snow-capped peaks, alpine meadow, clear blue sky, majestic scenic view, pristine nature, peaceful atmosphere",
98
+ negative_extra="industrial, polluted, crowded",
99
+ category="Nature",
100
+ icon="🏔️",
101
+ guidance_scale=7.5
102
+ ),
103
+ "garden_spring": SceneTemplate(
104
+ key="garden_spring",
105
+ name="Spring Garden",
106
+ prompt="beautiful spring flower garden, colorful blooming flowers, roses and tulips, manicured hedges, sunny day, botanical paradise, fresh and vibrant",
107
+ negative_extra="dead, winter, wilted, dry",
108
+ category="Nature",
109
+ icon="🌸",
110
+ guidance_scale=7.0
111
+ ),
112
+ "lake_serene": SceneTemplate(
113
+ key="lake_serene",
114
+ name="Serene Lake",
115
+ prompt="peaceful serene lake at dawn, mirror-like water reflection, surrounding mountains, soft morning light, tranquil atmosphere, pristine natural beauty",
116
+ negative_extra="stormy, polluted, industrial",
117
+ category="Nature",
118
+ icon="🏞️",
119
+ guidance_scale=7.0
120
+ ),
121
+ "cherry_blossom": SceneTemplate(
122
+ key="cherry_blossom",
123
+ name="Cherry Blossom",
124
+ prompt="stunning cherry blossom trees in full bloom, pink sakura petals falling gently, Japanese garden aesthetic, soft spring sunlight, romantic atmosphere",
125
+ negative_extra="winter, dead, brown, wilted",
126
+ category="Nature",
127
+ icon="🌸",
128
+ guidance_scale=7.0
129
+ ),
130
+
131
+ # === Urban Category ===
132
+ "city_skyline": SceneTemplate(
133
+ key="city_skyline",
134
+ name="City Skyline",
135
+ prompt="modern city skyline at blue hour, impressive skyscrapers, glass buildings reflecting sunset, urban metropolitan view, cinematic atmosphere",
136
+ negative_extra="slums, dirty, abandoned, ruins",
137
+ category="Urban",
138
+ icon="🌆",
139
+ guidance_scale=7.5
140
+ ),
141
+ "cafe_cozy": SceneTemplate(
142
+ key="cafe_cozy",
143
+ name="Cozy Cafe",
144
+ prompt="warm cozy coffee shop interior, wooden furniture, ambient lighting, exposed brick walls, plants, comfortable atmosphere, artisan cafe vibes",
145
+ negative_extra="fast food, plastic, harsh lighting",
146
+ category="Urban",
147
+ icon="☕",
148
+ guidance_scale=7.0
149
+ ),
150
+ "street_european": SceneTemplate(
151
+ key="street_european",
152
+ name="European Street",
153
+ prompt="charming European cobblestone street, historic buildings, outdoor cafe, flowers on balconies, warm afternoon light, romantic Paris or Rome vibes",
154
+ negative_extra="modern, industrial, ugly, dirty",
155
+ category="Urban",
156
+ icon="🏛️",
157
+ guidance_scale=7.0
158
+ ),
159
+ "night_neon": SceneTemplate(
160
+ key="night_neon",
161
+ name="Neon Nightlife",
162
+ prompt="vibrant city nightlife scene, neon lights and signs, urban night atmosphere, colorful reflections on wet street, cyberpunk aesthetic, electric energy",
163
+ negative_extra="daytime, boring, plain",
164
+ category="Urban",
165
+ icon="🌃",
166
+ guidance_scale=6.5
167
+ ),
168
+ "rooftop_view": SceneTemplate(
169
+ key="rooftop_view",
170
+ name="Rooftop Terrace",
171
+ prompt="luxury rooftop terrace, city panoramic view, modern outdoor furniture, string lights, sunset golden hour, sophisticated urban oasis",
172
+ negative_extra="cheap, dirty, crowded",
173
+ category="Urban",
174
+ icon="🏙️",
175
+ guidance_scale=7.5
176
+ ),
177
+
178
+ # === Artistic Category ===
179
+ "gradient_soft": SceneTemplate(
180
+ key="gradient_soft",
181
+ name="Soft Gradient",
182
+ prompt="smooth soft gradient background, pastel colors blending beautifully, pink to blue to purple transition, dreamy aesthetic, professional portrait backdrop",
183
+ negative_extra="harsh, noisy, textured, busy",
184
+ category="Artistic",
185
+ icon="🎨",
186
+ guidance_scale=8.0
187
+ ),
188
+ "abstract_modern": SceneTemplate(
189
+ key="abstract_modern",
190
+ name="Modern Abstract",
191
+ prompt="modern abstract art background, geometric shapes, bold colors, contemporary design, artistic composition, museum gallery aesthetic",
192
+ negative_extra="realistic, plain, boring",
193
+ category="Artistic",
194
+ icon="🖼️",
195
+ guidance_scale=6.5
196
+ ),
197
+ "vintage_retro": SceneTemplate(
198
+ key="vintage_retro",
199
+ name="Vintage Retro",
200
+ prompt="vintage retro aesthetic background, warm sepia tones, nostalgic 70s vibes, film grain texture, classic photography style, timeless elegance",
201
+ negative_extra="modern, digital, cold, harsh",
202
+ category="Artistic",
203
+ icon="📻",
204
+ guidance_scale=7.0
205
+ ),
206
+ "watercolor_dream": SceneTemplate(
207
+ key="watercolor_dream",
208
+ name="Watercolor Dream",
209
+ prompt="beautiful watercolor painting background, soft flowing colors, artistic brush strokes, dreamy ethereal atmosphere, delicate artistic aesthetic",
210
+ negative_extra="digital, sharp, photorealistic",
211
+ category="Artistic",
212
+ icon="🖌️",
213
+ guidance_scale=6.5
214
+ ),
215
+
216
+ # === Seasonal Category ===
217
+ "autumn_foliage": SceneTemplate(
218
+ key="autumn_foliage",
219
+ name="Autumn Foliage",
220
+ prompt="beautiful autumn scenery, vibrant fall foliage, orange red and golden leaves, maple trees, warm sunlight filtering through, cozy seasonal atmosphere",
221
+ negative_extra="spring, summer, green, snow",
222
+ category="Seasonal",
223
+ icon="🍂",
224
+ guidance_scale=7.0
225
+ ),
226
+ "winter_snow": SceneTemplate(
227
+ key="winter_snow",
228
+ name="Winter Wonderland",
229
+ prompt="magical winter wonderland, fresh white snow covering everything, snow-laden pine trees, soft snowfall, peaceful cold atmosphere, holiday season vibes",
230
+ negative_extra="summer, green, rain, mud",
231
+ category="Seasonal",
232
+ icon="❄️",
233
+ guidance_scale=7.0
234
+ ),
235
+ "summer_tropical": SceneTemplate(
236
+ key="summer_tropical",
237
+ name="Tropical Summer",
238
+ prompt="vibrant tropical summer scene, lush palm trees, bright sunny day, exotic flowers, paradise vacation destination, warm and inviting atmosphere",
239
+ negative_extra="winter, cold, snow, gray",
240
+ category="Seasonal",
241
+ icon="🌴",
242
+ guidance_scale=7.0
243
+ ),
244
+ "spring_meadow": SceneTemplate(
245
+ key="spring_meadow",
246
+ name="Spring Meadow",
247
+ prompt="beautiful spring meadow, wildflowers blooming, fresh green grass, butterflies, soft warm sunlight, renewal and new beginnings, pastoral beauty",
248
+ negative_extra="winter, autumn, dead, dry",
249
+ category="Seasonal",
250
+ icon="🌷",
251
+ guidance_scale=7.0
252
+ ),
253
+ }
254
+
255
+ # Category display order
256
+ CATEGORIES = ["Professional", "Nature", "Urban", "Artistic", "Seasonal"]
257
+
258
+ def __init__(self):
259
+ """Initialize the scene template manager"""
260
+ logger.info(f"SceneTemplateManager initialized with {len(self.TEMPLATES)} templates")
261
+
262
+ def get_all_templates(self) -> Dict[str, SceneTemplate]:
263
+ """Get all available templates"""
264
+ return self.TEMPLATES
265
+
266
+ def get_template(self, key: str) -> Optional[SceneTemplate]:
267
+ """Get a specific template by key"""
268
+ return self.TEMPLATES.get(key)
269
+
270
+ def get_templates_by_category(self, category: str) -> List[SceneTemplate]:
271
+ """Get all templates in a specific category"""
272
+ return [t for t in self.TEMPLATES.values() if t.category == category]
273
+
274
+ def get_categories(self) -> List[str]:
275
+ """Get list of all categories in display order"""
276
+ return self.CATEGORIES
277
+
278
+ def get_template_choices_sorted(self) -> List[str]:
279
+ """
280
+ Get template choices formatted for Gradio dropdown.
281
+ Returns list of display strings sorted A-Z: "🏢 Modern Office"
282
+ """
283
+ display_list = []
284
+ for key, template in self.TEMPLATES.items():
285
+ display_name = f"{template.icon} {template.name}"
286
+ display_list.append(display_name)
287
+
288
+ # Sort alphabetically by name (ignoring emoji)
289
+ display_list.sort(key=lambda x: x.split(' ', 1)[1] if ' ' in x else x)
290
+ return display_list
291
+
292
+ def get_template_key_from_display(self, display_name: str) -> Optional[str]:
293
+ """
294
+ Get template key from display name.
295
+ Example: "🏢 Modern Office" -> "office_modern"
296
+ """
297
+ if not display_name:
298
+ return None
299
+
300
+ for key, template in self.TEMPLATES.items():
301
+ if f"{template.icon} {template.name}" == display_name:
302
+ return key
303
+ return None
304
+
305
+ def get_prompt_for_template(self, key: str) -> Optional[str]:
306
+ """Get the prompt string for a template"""
307
+ template = self.get_template(key)
308
+ return template.prompt if template else None
309
+
310
+ def get_negative_prompt_for_template(
311
+ self,
312
+ key: str,
313
+ base_negative: str = "blurry, low quality, distorted, people, characters"
314
+ ) -> str:
315
+ """Get combined negative prompt for a template"""
316
+ template = self.get_template(key)
317
+ if template and template.negative_extra:
318
+ return f"{base_negative}, {template.negative_extra}"
319
+ return base_negative
320
+
321
+ def get_guidance_scale_for_template(self, key: str) -> float:
322
+ """Get the recommended guidance scale for a template"""
323
+ template = self.get_template(key)
324
+ return template.guidance_scale if template else 7.5
325
+
326
+ def build_gallery_html(self) -> str:
327
+ """
328
+ Build HTML for the scene template gallery.
329
+ Returns HTML string for display in Gradio.
330
+ """
331
+ html_parts = ['<div class="scene-gallery">']
332
+
333
+ for category in self.CATEGORIES:
334
+ templates = self.get_templates_by_category(category)
335
+ if not templates:
336
+ continue
337
+
338
+ html_parts.append(f'''
339
+ <div class="scene-category">
340
+ <h4 class="scene-category-title">{category}</h4>
341
+ <div class="scene-grid">
342
+ ''')
343
+
344
+ for template in templates:
345
+ html_parts.append(f'''
346
+ <button class="scene-card" data-template="{template.key}" onclick="selectTemplate('{template.key}')">
347
+ <span class="scene-icon">{template.icon}</span>
348
+ <span class="scene-name">{template.name}</span>
349
+ </button>
350
+ ''')
351
+
352
+ html_parts.append('</div></div>')
353
+
354
+ html_parts.append('</div>')
355
+ return ''.join(html_parts)
356
+
357
+ def get_gallery_css(self) -> str:
358
+ """Get CSS styles for the scene gallery"""
359
+ return """
360
+ /* Scene Gallery Styles */
361
+ .scene-gallery {
362
+ margin: 16px 0;
363
+ }
364
+
365
+ .scene-category {
366
+ margin-bottom: 20px;
367
+ }
368
+
369
+ .scene-category-title {
370
+ font-size: 0.9rem;
371
+ font-weight: 600;
372
+ color: #475569;
373
+ margin-bottom: 12px;
374
+ padding-bottom: 8px;
375
+ border-bottom: 1px solid #e2e8f0;
376
+ }
377
+
378
+ .scene-grid {
379
+ display: grid;
380
+ grid-template-columns: repeat(auto-fill, minmax(100px, 1fr));
381
+ gap: 8px;
382
+ }
383
+
384
+ .scene-card {
385
+ display: flex;
386
+ flex-direction: column;
387
+ align-items: center;
388
+ justify-content: center;
389
+ padding: 12px 8px;
390
+ background: #f8fafc;
391
+ border: 1px solid #e2e8f0;
392
+ border-radius: 8px;
393
+ cursor: pointer;
394
+ transition: all 0.2s ease;
395
+ min-height: 70px;
396
+ }
397
+
398
+ .scene-card:hover {
399
+ background: #dbeafe;
400
+ border-color: #3b82f6;
401
+ transform: translateY(-2px);
402
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
403
+ }
404
+
405
+ .scene-card.selected {
406
+ background: #dbeafe;
407
+ border-color: #3b82f6;
408
+ box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.3);
409
+ }
410
+
411
+ .scene-icon {
412
+ font-size: 1.5rem;
413
+ margin-bottom: 4px;
414
+ }
415
+
416
+ .scene-name {
417
+ font-size: 0.75rem;
418
+ font-weight: 500;
419
+ color: #1e293b;
420
+ text-align: center;
421
+ line-height: 1.2;
422
+ }
423
+
424
+ @media (max-width: 768px) {
425
+ .scene-grid {
426
+ grid-template-columns: repeat(3, 1fr);
427
+ }
428
+ }
429
+ """
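
# Editor's note: a minimal usage sketch of the template manager, showing how the UI layer
# is expected to call it. All method names are defined above; the chosen display name is
# only an example and any values shown in comments follow from the template definitions.
from scene_templates import SceneTemplateManager

manager = SceneTemplateManager()
choices = manager.get_template_choices_sorted()                  # dropdown entries, sorted A-Z
key = manager.get_template_key_from_display("🏖️ Sunset Beach")  # -> "beach_sunset"
prompt = manager.get_prompt_for_template(key)
negative = manager.get_negative_prompt_for_template(key)        # base negatives + template extras
guidance = manager.get_guidance_scale_for_template(key)         # 7.0 here; falls back to 7.5 for unknown keys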
scene_weaver_core.py ADDED
@@ -0,0 +1,808 @@
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ import logging
6
+ import gc
7
+ import time
8
+ from typing import Optional, Dict, Any, Tuple, List
9
+ from pathlib import Path
10
+ import warnings
11
+ warnings.filterwarnings("ignore")
12
+
13
+ from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
14
+ import open_clip
15
+ from mask_generator import MaskGenerator
16
+ from image_blender import ImageBlender
17
+ from quality_checker import QualityChecker
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.INFO)
21
+
22
+ class SceneWeaverCore:
23
+ """
24
+ Core SceneWeaver pipeline: SDXL background generation, foreground/background blending, and aggressive memory management for constrained GPUs
25
+ """
26
+
27
+ # Style presets for diversity generation mode
28
+ STYLE_PRESETS = {
29
+ "professional": {
30
+ "name": "Professional Business",
31
+ "modifier": "professional office environment, clean background, corporate setting, bright even lighting",
32
+ "negative_extra": "casual, messy, cluttered",
33
+ "guidance_scale": 8.0
34
+ },
35
+ "casual": {
36
+ "name": "Casual Lifestyle",
37
+ "modifier": "casual outdoor setting, natural environment, relaxed atmosphere, warm natural lighting",
38
+ "negative_extra": "formal, studio",
39
+ "guidance_scale": 7.5
40
+ },
41
+ "artistic": {
42
+ "name": "Artistic Creative",
43
+ "modifier": "artistic background, creative composition, vibrant colors, interesting lighting",
44
+ "negative_extra": "boring, plain",
45
+ "guidance_scale": 6.5
46
+ },
47
+ "nature": {
48
+ "name": "Natural Scenery",
49
+ "modifier": "beautiful natural scenery, outdoor landscape, scenic view, natural lighting",
50
+ "negative_extra": "urban, indoor",
51
+ "guidance_scale": 7.5
52
+ }
53
+ }
54
+
55
+ def __init__(self, device: str = "auto"):
56
+ self.device = self._setup_device(device)
57
+
58
+ # Model configurations - KEEP SAME FOR PERFECT GENERATION
59
+ self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
60
+ self.clip_model_name = "ViT-B-32"
61
+ self.clip_pretrained = "openai"
62
+
63
+ # Pipeline objects
64
+ self.pipeline = None
65
+ self.clip_model = None
66
+ self.clip_preprocess = None
67
+ self.clip_tokenizer = None
68
+ self.is_initialized = False
69
+
70
+ # Generation settings - KEEP SAME
71
+ self.max_image_size = 1024
72
+ self.default_steps = 25
73
+ self.use_fp16 = True
74
+
75
+ # Enhanced memory management
76
+ self.generation_count = 0
77
+ self.cleanup_frequency = 1 # More frequent cleanup
78
+ self.max_history = 3 # Limit generation history
79
+
80
+ # Initialize helper classes
81
+ self.mask_generator = MaskGenerator(self.max_image_size)
82
+ self.image_blender = ImageBlender()
83
+ self.quality_checker = QualityChecker()
84
+
85
+ logger.info(f"SceneWeaverCore initialized on {self.device}")
86
+
87
+ def _setup_device(self, device: str) -> str:
88
+ """Setup computation device"""
89
+ if device == "auto":
90
+ if torch.cuda.is_available():
91
+ return "cuda"
92
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
93
+ return "mps"
94
+ else:
95
+ return "cpu"
96
+ return device
97
+
98
+ def _ultra_memory_cleanup(self):
99
+ """Ultra aggressive memory cleanup for Colab stability"""
100
+ logger.debug("🧹 Ultra memory cleanup...")
101
+
102
+ # Multiple rounds of garbage collection
103
+ for i in range(5):
104
+ gc.collect()
105
+
106
+ if torch.cuda.is_available():
107
+ # Clear all cached memory
108
+ torch.cuda.empty_cache()
109
+ torch.cuda.ipc_collect()
110
+
111
+ # Force synchronization
112
+ torch.cuda.synchronize()
113
+
114
+ # Clear any remaining memory fragments
115
+ try:
116
+ torch.cuda.memory.empty_cache()
117
+ except Exception:
118
+ pass
119
+
120
+ logger.debug("✅ Ultra cleanup completed")
121
+
122
+ def load_models(self, progress_callback: Optional[callable] = None):
123
+ """Load AI models - KEEP SAME FOR PERFECT GENERATION"""
124
+ if self.is_initialized:
125
+ logger.info("Models already loaded")
126
+ return
127
+
128
+ logger.info("📥 Loading AI models...")
129
+
130
+ try:
131
+ self._ultra_memory_cleanup()
132
+
133
+ if progress_callback:
134
+ progress_callback("Loading OpenCLIP for image understanding...", 20)
135
+
136
+ # Load OpenCLIP - KEEP SAME
137
+ self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
138
+ self.clip_model_name,
139
+ pretrained=self.clip_pretrained,
140
+ device=self.device
141
+ )
142
+ self.clip_tokenizer = open_clip.get_tokenizer(self.clip_model_name)
143
+ self.clip_model.eval()
144
+
145
+ logger.info("✅ OpenCLIP loaded")
146
+
147
+ if progress_callback:
148
+ progress_callback("Loading SDXL text-to-image pipeline...", 60)
149
+
150
+ # Load standard SDXL text-to-image pipeline - KEEP SAME
151
+ self.pipeline = StableDiffusionXLPipeline.from_pretrained(
152
+ self.base_model_id,
153
+ torch_dtype=torch.float16 if self.use_fp16 else torch.float32,
154
+ use_safetensors=True,
155
+ variant="fp16" if self.use_fp16 else None
156
+ )
157
+
158
+ # Use DPM solver for faster generation - KEEP SAME
159
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
160
+ self.pipeline.scheduler.config
161
+ )
162
+
163
+ # Move to device
164
+ self.pipeline = self.pipeline.to(self.device)
165
+
166
+ if progress_callback:
167
+ progress_callback("Applying optimizations...", 90)
168
+
169
+ # Memory optimizations - ENHANCED
170
+ try:
171
+ self.pipeline.enable_xformers_memory_efficient_attention()
172
+ logger.info("✅ xformers enabled")
173
+ except Exception:
174
+ try:
175
+ self.pipeline.enable_attention_slicing()
176
+ logger.info("✅ Attention slicing enabled")
177
+ except Exception:
178
+ logger.warning("⚠️ No memory optimizations available")
179
+
180
+ # Additional memory optimizations
181
+ if hasattr(self.pipeline, 'enable_vae_tiling'):
182
+ self.pipeline.enable_vae_tiling()
183
+
184
+ if hasattr(self.pipeline, 'enable_vae_slicing'):
185
+ self.pipeline.enable_vae_slicing()
186
+
187
+ # Set to eval mode
188
+ self.pipeline.unet.eval()
189
+ if hasattr(self.pipeline, 'vae'):
190
+ self.pipeline.vae.eval()
191
+
192
+ # Enable sequential CPU offload if very low on memory
193
+ try:
194
+ if torch.cuda.is_available():
195
+ free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
196
+ if free_memory < 4 * 1024**3: # Less than 4GB free
197
+ self.pipeline.enable_sequential_cpu_offload()
198
+ logger.info("✅ Sequential CPU offload enabled for low memory")
199
+ except Exception:
200
+ pass
201
+
202
+ self.is_initialized = True
203
+
204
+ if progress_callback:
205
+ progress_callback("Models loaded successfully!", 100)
206
+
207
+ # Memory status
208
+ if torch.cuda.is_available():
209
+ memory_used = torch.cuda.memory_allocated() / 1024**3
210
+ memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
211
+ logger.info(f"📊 GPU Memory: {memory_used:.1f}GB / {memory_total:.1f}GB")
212
+
213
+ except Exception as e:
214
+ logger.error(f"❌ Model loading failed: {e}")
215
+ raise RuntimeError(f"Failed to load models: {str(e)}")
216
+
217
+ def analyze_image_with_clip(self, image: Image.Image) -> str:
218
+ """Analyze uploaded image using OpenCLIP - KEEP SAME"""
219
+ if not self.clip_model:
220
+ return "Image analysis not available"
221
+
222
+ try:
223
+ image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
224
+
225
+ categories = [
226
+ "a photo of a person",
227
+ "a photo of an animal",
228
+ "a photo of an object",
229
+ "a photo of a character",
230
+ "a photo of a cartoon",
231
+ "a photo of nature",
232
+ "a photo of a building",
233
+ "a photo of a landscape"
234
+ ]
235
+
236
+ text_inputs = self.clip_tokenizer(categories).to(self.device)
237
+
238
+ with torch.no_grad():
239
+ image_features = self.clip_model.encode_image(image_input)
240
+ text_features = self.clip_model.encode_text(text_inputs)
241
+
242
+ image_features /= image_features.norm(dim=-1, keepdim=True)
243
+ text_features /= text_features.norm(dim=-1, keepdim=True)
244
+
245
+ similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
246
+
247
+ best_match_idx = similarity.argmax().item()
248
+ confidence = similarity[0, best_match_idx].item()
249
+
250
+ category = categories[best_match_idx].replace("a photo of ", "")
251
+
252
+ return f"Detected: {category} (confidence: {confidence:.1%})"
253
+
254
+ except Exception as e:
255
+ logger.error(f"CLIP analysis failed: {e}")
256
+ return "Image analysis failed"
257
+
258
+ def enhance_prompt(
259
+ self,
260
+ user_prompt: str,
261
+ foreground_image: Image.Image
262
+ ) -> str:
263
+ """
264
+ Smart prompt enhancement based on image analysis.
265
+ Adds appropriate lighting, atmosphere, and quality descriptors.
266
+
267
+ Args:
268
+ user_prompt: Original user-provided prompt
269
+ foreground_image: Foreground image for analysis
270
+
271
+ Returns:
272
+ Enhanced prompt string
273
+ """
274
+ logger.info("✨ Enhancing prompt based on image analysis...")
275
+
276
+ try:
277
+ # Analyze image characteristics
278
+ img_array = np.array(foreground_image.convert('RGB'))
279
+
280
+ # === Analyze color temperature ===
281
+ # Convert to LAB to analyze color temperature
282
+ lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
283
+ avg_a = np.mean(lab[:, :, 1]) # a channel: green(-) to red(+)
284
+ avg_b = np.mean(lab[:, :, 2]) # b channel: blue(-) to yellow(+)
285
+
286
+ # Determine warm/cool tone
287
+ is_warm = avg_b > 128 # b > 128 means more yellow/warm
288
+
289
+ # === Analyze brightness ===
290
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
291
+ avg_brightness = np.mean(gray)
292
+ is_bright = avg_brightness > 127
293
+
294
+ # === Get subject type from CLIP ===
295
+ clip_analysis = self.analyze_image_with_clip(foreground_image)
296
+ subject_type = "unknown"
297
+
298
+ if "person" in clip_analysis.lower():
299
+ subject_type = "person"
300
+ elif "animal" in clip_analysis.lower():
301
+ subject_type = "animal"
302
+ elif "object" in clip_analysis.lower():
303
+ subject_type = "object"
304
+ elif "character" in clip_analysis.lower() or "cartoon" in clip_analysis.lower():
305
+ subject_type = "character"
306
+ elif "nature" in clip_analysis.lower() or "landscape" in clip_analysis.lower():
307
+ subject_type = "nature"
308
+
309
+ # === Build prompt fragments library ===
310
+ lighting_options = {
311
+ "warm_bright": "warm golden hour lighting, soft natural light",
312
+ "warm_dark": "warm ambient lighting, cozy atmosphere",
313
+ "cool_bright": "bright daylight, clear sky lighting",
314
+ "cool_dark": "soft diffused light, gentle shadows"
315
+ }
316
+
317
+ atmosphere_options = {
318
+ "person": "professional, elegant composition",
319
+ "animal": "natural, harmonious setting",
320
+ "object": "clean product photography style",
321
+ "character": "artistic, vibrant, imaginative",
322
+ "nature": "scenic, peaceful atmosphere",
323
+ "unknown": "balanced composition"
324
+ }
325
+
326
+ quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
327
+
328
+ # === Select appropriate fragments ===
329
+ # Lighting based on color temperature and brightness
330
+ if is_warm and is_bright:
331
+ lighting = lighting_options["warm_bright"]
332
+ elif is_warm and not is_bright:
333
+ lighting = lighting_options["warm_dark"]
334
+ elif not is_warm and is_bright:
335
+ lighting = lighting_options["cool_bright"]
336
+ else:
337
+ lighting = lighting_options["cool_dark"]
338
+
339
+ # Atmosphere based on subject type
340
+ atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])
341
+
342
+ # === Check for conflicts in user prompt ===
343
+ user_prompt_lower = user_prompt.lower()
344
+
345
+ # Avoid adding conflicting descriptions
346
+ if "sunset" in user_prompt_lower or "golden" in user_prompt_lower:
347
+ lighting = "" # User already specified lighting
348
+ if "dark" in user_prompt_lower or "night" in user_prompt_lower:
349
+ lighting = lighting.replace("bright", "").replace("daylight", "")
350
+
351
+ # === Combine enhanced prompt ===
352
+ fragments = [user_prompt]
353
+
354
+ if lighting:
355
+ fragments.append(lighting)
356
+ if atmosphere:
357
+ fragments.append(atmosphere)
358
+ fragments.append(quality_modifiers)
359
+
360
+ enhanced_prompt = ", ".join(filter(None, fragments))
361
+
362
+ logger.info(f"📝 Original prompt: {user_prompt[:50]}...")
363
+ logger.info(f"📝 Enhanced prompt: {enhanced_prompt[:80]}...")
364
+
365
+ return enhanced_prompt
366
+
367
+ except Exception as e:
368
+ logger.warning(f"⚠️ Prompt enhancement failed: {e}, using original prompt")
369
+ return user_prompt
370
+
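# Editor's note: the colour-temperature/brightness heuristic above can be sanity-checked in
# isolation. A standalone sketch follows; it relies on the fact that OpenCV's 8-bit LAB
# conversion centres the a and b channels at 128, which is what the 128/127 thresholds assume.
import cv2
import numpy as np
from PIL import Image

def classify_tone(image: Image.Image) -> tuple:
    """Return (is_warm, is_bright) using the same statistics as enhance_prompt."""
    rgb = np.array(image.convert("RGB"))
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    is_warm = float(np.mean(lab[:, :, 2])) > 128   # b channel: blue(-) .. yellow(+)
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    is_bright = float(np.mean(gray)) > 127
    return is_warm, is_bright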
371
+ def _prepare_image(self, image: Image.Image) -> Image.Image:
372
+ """Prepare image for processing - KEEP SAME"""
373
+ # Convert to RGB
374
+ if image.mode != 'RGB':
375
+ image = image.convert('RGB')
376
+
377
+ # Resize if too large
378
+ width, height = image.size
379
+ max_size = self.max_image_size
380
+
381
+ if width > max_size or height > max_size:
382
+ ratio = min(max_size/width, max_size/height)
383
+ new_width = int(width * ratio)
384
+ new_height = int(height * ratio)
385
+ image = image.resize((new_width, new_height), Image.LANCZOS)
386
+
387
+ # Ensure dimensions are multiple of 8
388
+ width, height = image.size
389
+ new_width = (width // 8) * 8
390
+ new_height = (height // 8) * 8
391
+
392
+ if new_width != width or new_height != height:
393
+ image = image.resize((new_width, new_height), Image.LANCZOS)
394
+
395
+ return image
396
+
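# Editor's note: a worked example of the sizing rules in _prepare_image. The 3000x2000 input
# and the helper name are hypothetical; the arithmetic mirrors the method above.
def snap_size(width: int, height: int, max_size: int = 1024) -> tuple:
    """Fit within max_size, then snap both dimensions down to multiples of 8."""
    ratio = min(max_size / width, max_size / height, 1.0)
    width, height = int(width * ratio), int(height * ratio)
    return (width // 8) * 8, (height // 8) * 8

# snap_size(3000, 2000) -> (1024, 680): the image scales to 1024x682, then 682 snaps down to 680.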
397
+ def generate_background(
398
+ self,
399
+ prompt: str,
400
+ width: int,
401
+ height: int,
402
+ negative_prompt: str = "blurry, low quality, distorted",
403
+ num_inference_steps: int = 25,
404
+ guidance_scale: float = 7.5,
405
+ progress_callback: Optional[callable] = None
406
+ ) -> Image.Image:
407
+ """Generate complete background using standard text-to-image - KEEP SAME"""
408
+
409
+ if not self.is_initialized:
410
+ raise RuntimeError("Models not loaded. Call load_models() first.")
411
+
412
+ logger.info(f"🎨 Generating background: {prompt[:50]}...")
413
+
414
+ try:
415
+ with torch.inference_mode():
416
+ if progress_callback:
417
+ progress_callback("Generating background with SDXL...", 50)
418
+
419
+ # Standard text-to-image generation - KEEP SAME
420
+ result = self.pipeline(
421
+ prompt=prompt,
422
+ negative_prompt=negative_prompt,
423
+ width=width,
424
+ height=height,
425
+ num_inference_steps=num_inference_steps,
426
+ guidance_scale=guidance_scale,
427
+ generator=torch.Generator(device=self.device).manual_seed(42)
428
+ )
429
+
430
+ generated_image = result.images[0]
431
+
432
+ if progress_callback:
433
+ progress_callback("Background generated successfully!", 100)
434
+
435
+ logger.info("✅ Background generation completed!")
436
+ return generated_image
437
+
438
+ except torch.cuda.OutOfMemoryError:
439
+ logger.error("❌ GPU memory exhausted")
440
+ self._ultra_memory_cleanup()
441
+ raise RuntimeError("GPU memory insufficient")
442
+
443
+ except Exception as e:
444
+ logger.error(f"❌ Background generation failed: {e}")
445
+ raise RuntimeError(f"Generation failed: {str(e)}")
446
+
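# Editor's note: calling the text-to-image step on its own. This sketch assumes an already
# initialized SceneWeaverCore instance named `core` (see the end-of-file example below) and
# dimensions that are multiples of 8, as _prepare_image guarantees.
background = core.generate_background(
    prompt="snow-capped mountains at dawn, photorealistic, high quality",
    width=1024,
    height=680,
    num_inference_steps=25,
    guidance_scale=7.5,
)
background.save("background.png")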
447
+ def generate_and_combine(
448
+ self,
449
+ original_image: Image.Image,
450
+ prompt: str,
451
+ combination_mode: str = "center",
452
+ focus_mode: str = "person",
453
+ negative_prompt: str = "blurry, low quality, distorted",
454
+ num_inference_steps: int = 25,
455
+ guidance_scale: float = 7.5,
456
+ progress_callback: Optional[callable] = None,
457
+ enable_prompt_enhancement: bool = True
458
+ ) -> Dict[str, Any]:
459
+ """
460
+ Generate background and combine with foreground using advanced blending.
461
+
462
+ Args:
463
+ original_image: Foreground image
464
+ prompt: User's background description
465
+ combination_mode: How to position foreground ("center", "left_half", "right_half", "full")
466
+ focus_mode: Focus type ("person" for tight crop, "scene" for wider context)
467
+ negative_prompt: What to avoid in generation
468
+ num_inference_steps: SDXL inference steps
469
+ guidance_scale: Classifier-free guidance scale
470
+ progress_callback: Progress reporting callback
471
+ enable_prompt_enhancement: Whether to use smart prompt enhancement
472
+
473
+ Returns:
474
+ Dictionary containing results and metadata
475
+ """
476
+
477
+ if not self.is_initialized:
478
+ raise RuntimeError("Models not loaded. Call load_models() first.")
479
+
480
+ logger.info(f"🎨 Starting generation and combination with advanced features...")
481
+
482
+ try:
483
+ # Enhanced memory management
484
+ if self.generation_count % self.cleanup_frequency == 0:
485
+ self._ultra_memory_cleanup()
486
+
487
+ if progress_callback:
488
+ progress_callback("Analyzing uploaded image...", 5)
489
+
490
+ # Analyze original image
491
+ image_analysis = self.analyze_image_with_clip(original_image)
492
+
493
+ if progress_callback:
494
+ progress_callback("Preparing images...", 10)
495
+
496
+ # Prepare original image
497
+ processed_original = self._prepare_image(original_image)
498
+ target_width, target_height = processed_original.size
499
+
500
+ if progress_callback:
501
+ progress_callback("Optimizing prompt...", 15)
502
+
503
+ # Smart prompt enhancement
504
+ if enable_prompt_enhancement:
505
+ enhanced_prompt = self.enhance_prompt(prompt, processed_original)
506
+ else:
507
+ enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic, beautiful scenery"
508
+
509
+ enhanced_negative = f"{negative_prompt}, people, characters, cartoons, logos"
510
+
511
+ if progress_callback:
512
+ progress_callback("Generating complete background scene...", 25)
513
+
514
+ def bg_progress(msg, pct):
515
+ if progress_callback:
516
+ progress_callback(f"Background: {msg}", 25 + (pct/100) * 50)
517
+
518
+ generated_background = self.generate_background(
519
+ prompt=enhanced_prompt,
520
+ width=target_width,
521
+ height=target_height,
522
+ negative_prompt=enhanced_negative,
523
+ num_inference_steps=num_inference_steps,
524
+ guidance_scale=guidance_scale,
525
+ progress_callback=bg_progress
526
+ )
527
+
528
+ if progress_callback:
529
+ progress_callback("Creating intelligent mask for person detection...", 80)
530
+
531
+ # Use intelligent mask generation with enhanced logging
532
+ logger.info("🎭 Starting intelligent mask generation...")
533
+ combination_mask = self.mask_generator.create_gradient_based_mask(
534
+ processed_original,
535
+ combination_mode,
536
+ focus_mode
537
+ )
538
+
539
+ # Log mask quality for debugging
540
+ try:
541
+ mask_array = np.array(combination_mask)
542
+ logger.info(f"📊 Generated mask stats - Mean: {mask_array.mean():.1f}, Non-zero pixels: {np.count_nonzero(mask_array)}")
543
+ except Exception as mask_debug_error:
544
+ logger.warning(f"⚠️ Mask debug logging failed: {mask_debug_error}")
545
+
546
+ if progress_callback:
547
+ progress_callback("Advanced image blending...", 90)
548
+
549
+ # Use advanced image blending with logging
550
+ logger.info("🖌️ Starting advanced image blending...")
551
+ combined_image = self.image_blender.simple_blend_images(
552
+ processed_original,
553
+ generated_background,
554
+ combination_mask
555
+ )
556
+ logger.info("✅ Image blending completed successfully")
557
+
558
+ if progress_callback:
559
+ progress_callback("Creating debug images...", 95)
560
+
561
+ # Generate debug images
562
+ debug_images = self.image_blender.create_debug_images(
563
+ processed_original,
564
+ generated_background,
565
+ combination_mask,
566
+ combined_image
567
+ )
568
+
569
+ # Memory cleanup after generation
570
+ self._ultra_memory_cleanup()
571
+
572
+ # Update generation count
573
+ self.generation_count += 1
574
+
575
+ if progress_callback:
576
+ progress_callback("Generation complete!", 100)
577
+
578
+ logger.info("✅ Complete generation and combination with fixed blending successful!")
579
+
580
+ return {
581
+ "combined_image": combined_image,
582
+ "generated_scene": generated_background,
583
+ "original_image": processed_original,
584
+ "combination_mask": combination_mask,
585
+ "debug_mask_gray": debug_images["mask_gray"],
586
+ "debug_alpha_heatmap": debug_images["alpha_heatmap"],
587
+ "image_analysis": image_analysis,
588
+ "enhanced_prompt": enhanced_prompt,
589
+ "original_prompt": prompt,
590
+ "success": True,
591
+ "generation_count": self.generation_count
592
+ }
593
+
594
+ except Exception as e:
595
+ import traceback
596
+ error_traceback = traceback.format_exc()
597
+ logger.error(f"❌ Generation and combination failed: {str(e)}")
598
+ logger.error(f"📍 Full traceback:\n{error_traceback}")
599
+ print(f"❌ DETAILED ERROR in scene_weaver_core.generate_and_combine:")
600
+ print(f"Error: {str(e)}")
601
+ print(f"Traceback:\n{error_traceback}")
602
+ self._ultra_memory_cleanup() # Cleanup on error too
603
+ return {
604
+ "success": False,
605
+ "error": f"Failed: {str(e)}"
606
+ }
607
+
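# Editor's note: every long-running method in this class accepts the same callback shape,
# a callable taking a human-readable message and a 0-100 percentage. A minimal console
# implementation (the function name is illustrative):
def console_progress(message: str, percent: float) -> None:
    print(f"[{percent:5.1f}%] {message}")

# result = core.generate_and_combine(image, "sunset beach", progress_callback=console_progress)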
608
+ def generate_diversity_variants(
609
+ self,
610
+ original_image: Image.Image,
611
+ prompt: str,
612
+ selected_styles: Optional[List[str]] = None,
613
+ combination_mode: str = "center",
614
+ focus_mode: str = "person",
615
+ negative_prompt: str = "blurry, low quality, distorted",
616
+ progress_callback: Optional[callable] = None
617
+ ) -> Dict[str, Any]:
618
+ """
619
+ Generate multiple style variants of the background.
620
+ Uses reduced quality for faster preview generation.
621
+
622
+ Args:
623
+ original_image: Foreground image
624
+ prompt: Base background description
625
+ selected_styles: List of style keys to use (None = all styles)
626
+ combination_mode: Foreground positioning mode
627
+ focus_mode: Focus type for mask generation
628
+ negative_prompt: Base negative prompt
629
+ progress_callback: Progress callback function
630
+
631
+ Returns:
632
+ Dictionary containing variants and metadata
633
+ """
634
+ if not self.is_initialized:
635
+ raise RuntimeError("Models not loaded. Call load_models() first.")
636
+
637
+ logger.info("🎨 Starting diversity generation mode...")
638
+
639
+ # Determine which styles to generate
640
+ styles_to_generate = selected_styles or list(self.STYLE_PRESETS.keys())
641
+ num_styles = len(styles_to_generate)
642
+
643
+ results = {
644
+ "variants": [],
645
+ "success": True,
646
+ "num_variants": 0
647
+ }
648
+
649
+ try:
650
+ # Pre-process image once
651
+ processed_original = self._prepare_image(original_image)
652
+ target_width, target_height = processed_original.size
653
+
654
+ # Reduce resolution for faster generation
655
+ preview_size = min(768, max(target_width, target_height))
656
+ scale = preview_size / max(target_width, target_height)
657
+ preview_width = int(target_width * scale) // 8 * 8
658
+ preview_height = int(target_height * scale) // 8 * 8
659
+
660
+ # Generate mask once (reusable for all variants)
661
+ if progress_callback:
662
+ progress_callback("Creating foreground mask...", 5)
663
+
664
+ combination_mask = self.mask_generator.create_gradient_based_mask(
665
+ processed_original, combination_mode, focus_mode
666
+ )
667
+
668
+ # Resize mask for preview resolution
669
+ preview_mask = combination_mask.resize((preview_width, preview_height), Image.LANCZOS)
670
+ preview_original = processed_original.resize((preview_width, preview_height), Image.LANCZOS)
671
+
672
+ # Generate each style variant
673
+ for idx, style_key in enumerate(styles_to_generate):
674
+ if style_key not in self.STYLE_PRESETS:
675
+ logger.warning(f"⚠️ Unknown style: {style_key}, skipping")
676
+ continue
677
+
678
+ style = self.STYLE_PRESETS[style_key]
679
+ style_name = style["name"]
680
+
681
+ if progress_callback:
682
+ base_pct = 10 + (idx / num_styles) * 80
683
+ progress_callback(f"Generating {style_name} variant...", int(base_pct))
684
+
685
+ logger.info(f"🎨 Generating variant: {style_name}")
686
+
687
+ try:
688
+ # Build style-specific prompt
689
+ styled_prompt = f"{prompt}, {style['modifier']}, high quality, detailed"
690
+ styled_negative = f"{negative_prompt}, {style['negative_extra']}, people, characters"
691
+
692
+ # Generate background with reduced steps for speed
693
+ background = self.generate_background(
694
+ prompt=styled_prompt,
695
+ width=preview_width,
696
+ height=preview_height,
697
+ negative_prompt=styled_negative,
698
+ num_inference_steps=15, # Reduced for speed
699
+ guidance_scale=style["guidance_scale"]
700
+ )
701
+
702
+ # Blend images
703
+ combined = self.image_blender.simple_blend_images(
704
+ preview_original,
705
+ background,
706
+ preview_mask,
707
+ use_multi_scale=False # Skip for speed
708
+ )
709
+
710
+ results["variants"].append({
711
+ "style_key": style_key,
712
+ "style_name": style_name,
713
+ "combined_image": combined,
714
+ "background": background,
715
+ "prompt_used": styled_prompt
716
+ })
717
+
718
+ # Memory cleanup between variants
719
+ self._ultra_memory_cleanup()
720
+
721
+ except Exception as variant_error:
722
+ logger.error(f"❌ Failed to generate {style_name} variant: {variant_error}")
723
+ continue
724
+
725
+ results["num_variants"] = len(results["variants"])
726
+
727
+ if progress_callback:
728
+ progress_callback("Diversity generation complete!", 100)
729
+
730
+ logger.info(f"✅ Generated {results['num_variants']} style variants")
731
+ return results
732
+
733
+ except Exception as e:
734
+ logger.error(f"❌ Diversity generation failed: {e}")
735
+ self._ultra_memory_cleanup()
736
+ return {
737
+ "variants": [],
738
+ "success": False,
739
+ "error": str(e),
740
+ "num_variants": 0
741
+ }
742
+
743
+ def regenerate_high_quality(
744
+ self,
745
+ original_image: Image.Image,
746
+ prompt: str,
747
+ style_key: str,
748
+ combination_mode: str = "center",
749
+ focus_mode: str = "person",
750
+ negative_prompt: str = "blurry, low quality, distorted",
751
+ progress_callback: Optional[callable] = None
752
+ ) -> Dict[str, Any]:
753
+ """
754
+ Regenerate a specific style at full quality.
755
+
756
+ Args:
757
+ original_image: Original foreground image
758
+ prompt: Base prompt
759
+ style_key: Style preset key to use
760
+ combination_mode: Foreground positioning
761
+ focus_mode: Mask focus mode
762
+ negative_prompt: Base negative prompt
763
+ progress_callback: Progress callback
764
+
765
+ Returns:
766
+ Full quality result dictionary
767
+ """
768
+ if style_key not in self.STYLE_PRESETS:
769
+ return {"success": False, "error": f"Unknown style: {style_key}"}
770
+
771
+ style = self.STYLE_PRESETS[style_key]
772
+
773
+ # Build styled prompt
774
+ styled_prompt = f"{prompt}, {style['modifier']}"
775
+ styled_negative = f"{negative_prompt}, {style['negative_extra']}"
776
+
777
+ # Use full generate_and_combine with style parameters
778
+ return self.generate_and_combine(
779
+ original_image=original_image,
780
+ prompt=styled_prompt,
781
+ combination_mode=combination_mode,
782
+ focus_mode=focus_mode,
783
+ negative_prompt=styled_negative,
784
+ num_inference_steps=25, # Full quality
785
+ guidance_scale=style["guidance_scale"],
786
+ progress_callback=progress_callback,
787
+ enable_prompt_enhancement=True
788
+ )
789
+
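# Editor's note: the intended preview-then-refine flow combines the two methods above.
# `core` is an initialized SceneWeaverCore and `image` a PIL image; file names are illustrative.
previews = core.generate_diversity_variants(
    original_image=image,
    prompt="cozy cafe interior",
    selected_styles=["professional", "casual"],   # subset of STYLE_PRESETS keys
)
for variant in previews["variants"]:
    variant["combined_image"].save(f"preview_{variant['style_key']}.png")

# After the user picks a style, re-run it at full quality (25 steps):
final = core.regenerate_high_quality(image, "cozy cafe interior", style_key="casual")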
790
+ def get_memory_status(self) -> Dict[str, Any]:
791
+ """Enhanced memory status reporting"""
792
+ status = {"device": self.device}
793
+
794
+ if torch.cuda.is_available():
795
+ allocated = torch.cuda.memory_allocated() / 1024**3
796
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
797
+ cached = torch.cuda.memory_reserved() / 1024**3
798
+
799
+ status.update({
800
+ "gpu_allocated_gb": round(allocated, 2),
801
+ "gpu_total_gb": round(total, 2),
802
+ "gpu_cached_gb": round(cached, 2),
803
+ "gpu_free_gb": round(total - allocated, 2),
804
+ "gpu_usage_percent": round((allocated / total) * 100, 1),
805
+ "generation_count": self.generation_count
806
+ })
807
+
808
+ return status
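
# Editor's note: end-to-end use of the core class outside the Gradio UI. A sketch only;
# the input path is hypothetical, everything else uses the APIs defined above.
from PIL import Image
from scene_weaver_core import SceneWeaverCore

core = SceneWeaverCore(device="auto")
core.load_models()

result = core.generate_and_combine(
    original_image=Image.open("subject.png"),   # hypothetical input
    prompt="modern office with large windows",
    combination_mode="center",
    focus_mode="person",
)
if result["success"]:
    result["combined_image"].save("combined.png")
    print(result["image_analysis"])
    print(core.get_memory_status())
else:
    print(result["error"])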
ui_manager.py ADDED
@@ -0,0 +1,513 @@
1
+ import logging
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple
5
+ from PIL import Image
6
+ import numpy as np
7
+ import cv2
8
+ import gradio as gr
9
+ import spaces
10
+
11
+ from scene_weaver_core import SceneWeaverCore
12
+ from css_styles import CSSStyles
13
+ from scene_templates import SceneTemplateManager
14
+
15
+ logger = logging.getLogger(__name__)
16
+ logger.setLevel(logging.INFO)
17
+
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format='%(asctime)s [%(name)s] %(levelname)s: %(message)s',
21
+ datefmt='%H:%M:%S'
22
+ )
23
+
24
+
25
+ class UIManager:
26
+ """Gradio UI with enhanced memory management and professional design"""
27
+
28
+ def __init__(self):
29
+ self.sceneweaver = SceneWeaverCore()
30
+ self.template_manager = SceneTemplateManager()
31
+ self.generation_history = []
32
+ self._preview_sensitivity = 0.5
33
+
34
+ def apply_template(self, display_name: str, current_negative: str) -> Tuple[str, str, float]:
35
+ """
36
+ Apply a scene template to the prompt fields.
37
+
38
+ Args:
39
+ display_name: The display name from dropdown (e.g., "🏢 Modern Office")
40
+ current_negative: Current negative prompt value
41
+
42
+ Returns:
43
+ Tuple of (prompt, negative_prompt, guidance_scale)
44
+ """
45
+ if not display_name:
46
+ return "", current_negative, 7.5
47
+
48
+ # Convert display name to template key
49
+ template_key = self.template_manager.get_template_key_from_display(display_name)
50
+ if not template_key:
51
+ return "", current_negative, 7.5
52
+
53
+ template = self.template_manager.get_template(template_key)
54
+ if template:
55
+ prompt = template.prompt
56
+ negative = self.template_manager.get_negative_prompt_for_template(
57
+ template_key, current_negative
58
+ )
59
+ guidance = template.guidance_scale
60
+ return prompt, negative, guidance
61
+
62
+ return "", current_negative, 7.5
63
+
64
+ def quick_preview(
65
+ self,
66
+ uploaded_image: Optional[Image.Image],
67
+ sensitivity: float = 0.5
68
+ ) -> Optional[Image.Image]:
69
+ """
70
+ Generate quick foreground preview using lightweight traditional methods.
71
+
72
+ Args:
73
+ uploaded_image: Uploaded PIL Image
74
+ sensitivity: Detection sensitivity (0.0 - 1.0)
75
+
76
+ Returns:
77
+ Preview image with colored overlay or None
78
+ """
79
+ if uploaded_image is None:
80
+ return None
81
+
82
+ try:
83
+ logger.info(f"Generating quick preview (sensitivity={sensitivity:.2f})")
84
+
85
+ img_array = np.array(uploaded_image.convert('RGB'))
86
+ height, width = img_array.shape[:2]
87
+
88
+ max_preview_size = 512
89
+ if max(width, height) > max_preview_size:
90
+ scale = max_preview_size / max(width, height)
91
+ new_w = int(width * scale)
92
+ new_h = int(height * scale)
93
+ img_array = cv2.resize(img_array, (new_w, new_h), interpolation=cv2.INTER_AREA)
94
+ height, width = new_h, new_w
95
+
96
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
97
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
98
+
99
+ low_threshold = int(30 + (1 - sensitivity) * 50)
100
+ high_threshold = int(100 + (1 - sensitivity) * 100)
101
+ edges = cv2.Canny(blurred, low_threshold, high_threshold)
102
+
103
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
104
+ dilated = cv2.dilate(edges, kernel, iterations=2)
105
+
106
+ contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
107
+
108
+ mask = np.zeros((height, width), dtype=np.uint8)
109
+
110
+ if contours:
111
+ sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)
112
+ min_area = (width * height) * 0.01 * (1 - sensitivity)
113
+ for contour in sorted_contours:
114
+ if cv2.contourArea(contour) > min_area:
115
+ cv2.fillPoly(mask, [contour], 255)
116
+
117
+ kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
118
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_close)
119
+
120
+ overlay = img_array.copy().astype(np.float32)
121
+
122
+ fg_mask = mask > 127
123
+ overlay[fg_mask] = overlay[fg_mask] * 0.5 + np.array([0, 255, 0]) * 0.5
124
+
125
+ bg_mask = mask <= 127
126
+ overlay[bg_mask] = overlay[bg_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
127
+
128
+ overlay = np.clip(overlay, 0, 255).astype(np.uint8)
129
+
130
+ original_size = uploaded_image.size
131
+ preview_image = Image.fromarray(overlay)
132
+ if preview_image.size != original_size:
133
+ preview_image = preview_image.resize(original_size, Image.LANCZOS)
134
+
135
+ logger.info("Quick preview generated successfully")
136
+ return preview_image
137
+
138
+ except Exception as e:
139
+ logger.error(f"Quick preview failed: {e}")
140
+ return None
141
+
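# Editor's note: how the sensitivity slider maps onto the Canny thresholds used above,
# derived directly from the two formulas in the method:
#   low  = 30  + (1 - sensitivity) * 50
#   high = 100 + (1 - sensitivity) * 100
#
#   sensitivity = 0.0  ->  Canny(80, 200)   strictest edges, smallest foreground region
#   sensitivity = 0.5  ->  Canny(55, 150)   default
#   sensitivity = 1.0  ->  Canny(30, 100)   loosest edges, largest foreground region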
142
+ def _save_result(self, combined_image: Image.Image, prompt: str):
143
+ """Save result with memory-conscious history management"""
144
+ if not combined_image:
145
+ return
146
+
147
+ output_dir = Path("outputs")
148
+ output_dir.mkdir(exist_ok=True)
149
+
150
+ combined_image.save(output_dir / "latest_combined.png")
151
+
152
+ self.generation_history.append({
153
+ "prompt": prompt,
154
+ "timestamp": time.time()
155
+ })
156
+
157
+ max_history = self.sceneweaver.max_history
158
+ if len(self.generation_history) > max_history:
159
+ self.generation_history = self.generation_history[-max_history:]
160
+
161
+ @spaces.GPU(duration=120)
162
+ def generate_handler(
163
+ self,
164
+ uploaded_image: Optional[Image.Image],
165
+ prompt: str,
166
+ combination_mode: str,
167
+ focus_mode: str,
168
+ negative_prompt: str,
169
+ steps: int,
170
+ guidance: float,
171
+ progress=gr.Progress()
172
+ ):
173
+ """Enhanced generation handler with memory management and ZeroGPU support"""
174
+
175
+ if uploaded_image is None:
176
+ return None, None, None, "Please upload an image to get started!", gr.update(visible=False)
177
+
178
+ if not prompt.strip():
179
+ return None, None, None, "Please describe the background scene you'd like!", gr.update(visible=False)
180
+
181
+ try:
182
+ if not self.sceneweaver.is_initialized:
183
+ progress(0.05, desc="Loading AI models (first time may take 2-3 minutes)...")
184
+
185
+ def init_progress(msg, pct):
186
+ if pct < 30:
187
+ desc = "Loading image analysis models..."
188
+ elif pct < 60:
189
+ desc = "Loading Stable Diffusion XL..."
190
+ elif pct < 90:
191
+ desc = "Applying memory optimizations..."
192
+ else:
193
+ desc = "Almost ready..."
194
+ progress(0.05 + (pct/100) * 0.2, desc=desc)
195
+
196
+ self.sceneweaver.load_models(progress_callback=init_progress)
197
+
198
+ def gen_progress(msg, pct):
199
+ if pct < 20:
200
+ desc = "Analyzing your image..."
201
+ elif pct < 50:
202
+ desc = "Generating background scene..."
203
+ elif pct < 80:
204
+ desc = "Blending foreground and background..."
205
+ elif pct < 95:
206
+ desc = "Applying final touches..."
207
+ else:
208
+ desc = "Complete!"
209
+ progress(0.25 + (pct/100) * 0.75, desc=desc)
210
+
211
+ result = self.sceneweaver.generate_and_combine(
212
+ original_image=uploaded_image,
213
+ prompt=prompt,
214
+ combination_mode=combination_mode,
215
+ focus_mode=focus_mode,
216
+ negative_prompt=negative_prompt,
217
+ num_inference_steps=int(steps),
218
+ guidance_scale=float(guidance),
219
+ progress_callback=gen_progress
220
+ )
221
+
222
+ if result["success"]:
223
+ combined = result["combined_image"]
224
+ generated = result["generated_scene"]
225
+ original = result["original_image"]
226
+
227
+ self._save_result(combined, prompt)
228
+
229
+ status_msg = "Image created successfully!"
230
+
231
+ return combined, generated, original, status_msg, gr.update(visible=True)
232
+ else:
233
+ error_msg = result.get("error", "Something went wrong")
234
+ return None, None, None, f"Error: {error_msg}", gr.update(visible=False)
235
+
236
+ except Exception as e:
237
+ import traceback
238
+ error_traceback = traceback.format_exc()
239
+ logger.error(f"Generation handler error: {str(e)}")
240
+ logger.error(f"Traceback:\n{error_traceback}")
241
+ return None, None, None, f"Error: {str(e)}", gr.update(visible=False)
242
+
243
+ def create_interface(self):
244
+ """Create professional user interface"""
245
+
246
+ css = CSSStyles.get_main_css()
247
+
248
+ with gr.Blocks(
249
+ css=css,
250
+ title="SceneWeaver - AI Background Generator",
251
+ theme=gr.themes.Soft()
252
+ ) as interface:
253
+
254
+ # Header
255
+ gr.HTML("""
256
+ <div class="main-header">
257
+ <h1 class="main-title">
258
+ <span class="title-emoji">🎨</span>
259
+ SceneWeaver
260
+ </h1>
261
+ <p class="main-subtitle">AI-powered background generation with professional edge processing</p>
262
+ </div>
263
+ """)
264
+
265
+ with gr.Row():
266
+ # Left Column - Input controls
267
+ with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
268
+ gr.HTML("""
269
+ <div class="card-content">
270
+ <h3 class="card-title">
271
+ <span class="section-emoji">📸</span>
272
+ Upload & Generate
273
+ </h3>
274
+ </div>
275
+ """)
276
+
277
+ uploaded_image = gr.Image(
278
+ label="Upload Your Image",
279
+ type="pil",
280
+ height=280,
281
+ elem_classes=["input-field"]
282
+ )
283
+
284
+ # Scene Template Selector
285
+ with gr.Accordion("Scene Templates", open=False):
286
+ template_dropdown = gr.Dropdown(
287
+ label="Select a Scene",
288
+ choices=[""] + self.template_manager.get_template_choices_sorted(),
289
+ value="",
290
+ info="24 curated scenes sorted A-Z",
291
+ elem_classes=["template-dropdown"]
292
+ )
293
+
294
+ prompt_input = gr.Textbox(
295
+ label="Background Scene Description",
296
+ placeholder="Select a template above or describe your own scene...",
297
+ lines=3,
298
+ elem_classes=["input-field"]
299
+ )
300
+
301
+ combination_mode = gr.Dropdown(
302
+ label="Composition Mode",
303
+ choices=["center", "left_half", "right_half", "full"],
304
+ value="center",
305
+ info="center=Smart Center | left_half=Left Half | right_half=Right Half | full=Full Image",
306
+ elem_classes=["input-field"]
307
+ )
308
+
309
+ focus_mode = gr.Dropdown(
310
+ label="Focus Mode",
311
+ choices=["person", "scene"],
312
+ value="person",
313
+ info="person=Tight Crop | scene=Include Surrounding Objects",
314
+ elem_classes=["input-field"]
315
+ )
316
+
317
+ with gr.Accordion("Advanced Options", open=False):
318
+ negative_prompt = gr.Textbox(
319
+ label="Negative Prompt",
320
+ value="blurry, low quality, distorted, people, characters",
321
+ lines=2,
322
+ elem_classes=["input-field"]
323
+ )
324
+
325
+ steps_slider = gr.Slider(
326
+ label="Quality Steps",
327
+ minimum=15,
328
+ maximum=50,
329
+ value=25,
330
+ step=5,
331
+ elem_classes=["input-field"]
332
+ )
333
+
334
+ guidance_slider = gr.Slider(
335
+ label="Guidance Scale",
336
+ minimum=5.0,
337
+ maximum=15.0,
338
+ value=7.5,
339
+ step=0.5,
340
+ elem_classes=["input-field"]
341
+ )
342
+
343
+ generate_btn = gr.Button(
344
+ "Generate Background",
345
+ variant="primary",
346
+ size="lg",
347
+ elem_classes=["primary-button"]
348
+ )
349
+
350
+ # Right Column - Results display
351
+ with gr.Column(scale=2, elem_classes=["feature-card"], elem_id="results-gallery-centered"):
352
+ gr.HTML("""
353
+ <div class="card-content">
354
+ <h3 class="card-title">
355
+ <span class="section-emoji">🎭</span>
356
+ Results Gallery
357
+ </h3>
358
+ </div>
359
+ """)
360
+
361
+ # Loading notice
362
+ gr.HTML("""
363
+ <div class="loading-notice">
364
+ <span class="loading-notice-icon">⏱️</span>
365
+ <span class="loading-notice-text">
366
+ <strong>First-time users:</strong> Initial model loading takes 1-2 minutes.
367
+ Subsequent generations are much faster (~30s).
368
+ </span>
369
+ </div>
370
+ """)
371
+
372
+ # Quick start guide
373
+ gr.HTML("""
374
+ <details class="user-guidance-panel">
375
+ <summary class="guidance-summary">
376
+ <span class="emoji-enhanced">💡</span>
377
+ Quick Start Guide
378
+ </summary>
379
+ <div class="guidance-content">
380
+ <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
381
+ <p><strong>Step 2:</strong> Describe or choose your desired background scene</p>
382
+ <p><strong>Step 3:</strong> Choose composition mode (center works best)</p>
383
+ <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
384
+ <p><strong>Tip:</strong> For dark clothing, ensure good lighting in the original photo.</p>
385
+ </div>
386
+ </details>
387
+ """)
388
+
389
+ with gr.Tabs():
390
+ with gr.TabItem("Final Result"):
391
+ combined_output = gr.Image(
392
+ label="Your Generated Image",
393
+ elem_classes=["result-gallery"],
394
+ show_label=False
395
+ )
396
+ with gr.TabItem("Background"):
397
+ generated_output = gr.Image(
398
+ label="Generated Background",
399
+ elem_classes=["result-gallery"],
400
+ show_label=False
401
+ )
402
+ with gr.TabItem("Original"):
403
+ original_output = gr.Image(
404
+ label="Processed Original",
405
+ elem_classes=["result-gallery"],
406
+ show_label=False
407
+ )
408
+
409
+ status_output = gr.Textbox(
410
+ label="Status",
411
+ value="Ready to create! Upload an image and describe your vision.",
412
+ interactive=False,
413
+ elem_classes=["status-panel", "status-ready"]
414
+ )
415
+
416
+ with gr.Row():
417
+ download_btn = gr.DownloadButton(
418
+ "Download Result",
419
+ value=None,
420
+ visible=False,
421
+ elem_classes=["secondary-button"]
422
+ )
423
+ clear_btn = gr.Button(
424
+ "Clear All",
425
+ elem_classes=["secondary-button"]
426
+ )
427
+ memory_btn = gr.Button(
428
+ "Clean Memory",
429
+ elem_classes=["secondary-button"]
430
+ )
431
+
432
+ # Footer with tech credits
433
+ gr.HTML("""
434
+ <div class="app-footer">
435
+ <div class="footer-powered">
436
+ <p class="footer-powered-title">Powered By</p>
437
+ <div class="footer-tech-grid">
438
+ <span class="footer-tech-item">Stable Diffusion XL</span>
439
+ <span class="footer-tech-item">OpenCLIP</span>
440
+ <span class="footer-tech-item">BiRefNet</span>
441
+ <span class="footer-tech-item">rembg</span>
442
+ <span class="footer-tech-item">PyTorch</span>
443
+ <span class="footer-tech-item">Gradio</span>
444
+ </div>
445
+ </div>
446
+ <div class="footer-divider"></div>
447
+ <p class="footer-copyright">
448
+ SceneWeaver &copy; 2025 &nbsp;|&nbsp;
449
+ Built with <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">SDXL</a>
450
+ and <a href="https://github.com/mlfoundations/open_clip" target="_blank">OpenCLIP</a>
451
+ </p>
452
+ </div>
453
+ """)
454
+
455
+ # Event handlers
456
+ # Template selection handler
457
+ template_dropdown.change(
458
+ fn=self.apply_template,
459
+ inputs=[template_dropdown, negative_prompt],
460
+ outputs=[prompt_input, negative_prompt, guidance_slider]
461
+ )
462
+
463
+ generate_btn.click(
464
+ fn=self.generate_handler,
465
+ inputs=[
466
+ uploaded_image,
467
+ prompt_input,
468
+ combination_mode,
469
+ focus_mode,
470
+ negative_prompt,
471
+ steps_slider,
472
+ guidance_slider
473
+ ],
474
+ outputs=[
475
+ combined_output,
476
+ generated_output,
477
+ original_output,
478
+ status_output,
479
+ download_btn
480
+ ]
481
+ )
482
+
483
+ clear_btn.click(
484
+ fn=lambda: (None, None, None, "Ready to create!", gr.update(visible=False)),
485
+ outputs=[combined_output, generated_output, original_output, status_output, download_btn]
486
+ )
487
+
488
+ memory_btn.click(
489
+ fn=lambda: self.sceneweaver._ultra_memory_cleanup() or "Memory cleaned!",
490
+ outputs=[status_output]
491
+ )
492
+
493
+ combined_output.change(
494
+ fn=lambda img: gr.update(value="outputs/latest_combined.png", visible=True) if (img is not None) else gr.update(visible=False),
495
+ inputs=[combined_output],
496
+ outputs=[download_btn]
497
+ )
498
+
499
+ return interface
500
+
501
+ def launch(self, share: bool = True, debug: bool = False):
502
+ """Launch the UI interface"""
503
+ interface = self.create_interface()
504
+
505
+ return interface.launch(
506
+ share=share,
507
+ debug=debug,
508
+ show_error=True,
509
+ height=800,
510
+ favicon_path=None,
511
+ ssl_verify=False,
512
+ quiet=False
513
+ )