import glob
import gradio as gr
import matplotlib
import numpy as np
from PIL import Image
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2
import os
import tensorflow as tf
from tensorflow.keras.models import load_model

# Classification imports
from transformers import AutoImageProcessor, AutoModelForImageClassification
import google.generativeai as genai
import gdown
import spaces
import cv2

# Import actual segmentation model components
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
from utils.learning.metrics import dice_coef, precision, recall
from utils.io.data import normalize

# --- Classification Model Setup ---
# Load classification model and processor
classification_processor = AutoImageProcessor.from_pretrained("Hemg/Wound-classification")
classification_model = AutoModelForImageClassification.from_pretrained("Hemg/Wound-classification")

# Configure Gemini AI
try:
    # Try to get the API key from Hugging Face secrets (either secret name works)
    gemini_api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
    if not gemini_api_key:
        raise ValueError("GEMINI_API_KEY (or GOOGLE_API_KEY) not found in environment variables")
    genai.configure(api_key=gemini_api_key)
    gemini_model = genai.GenerativeModel("gemini-2.5-pro")
    print("✅ Gemini AI configured successfully with API key from secrets")
except Exception as e:
    print(f"❌ Error configuring Gemini AI: {e}")
    print("Please make sure GEMINI_API_KEY is set in your Hugging Face Space secrets")
    gemini_model = None

# --- Classification Functions ---
def analyze_wound_with_gemini(image, predicted_label):
    """
    Analyze a wound image using Gemini AI with classification context.

    Args:
        image: PIL Image
        predicted_label: The predicted wound type from the classification model
    Returns:
        str: Gemini AI analysis
    """
    if image is None:
        return "No image provided for analysis."
    if gemini_model is None:
        return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."
    try:
        # Ensure the image is in RGB format
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Create a prompt that includes the classification result
        prompt = f"""You are assisting in a medical education and research task.
Based on the wound classification model, this image has been identified as: {predicted_label}
Please provide an educational analysis of this wound image focusing on:
1. Visible characteristics of the wound (size, color, texture, edges, surrounding tissue)
2. Educational explanation about this type of wound based on the classification: {predicted_label}
3. General wound healing stages if applicable
4. Key features that are typically associated with this wound type
Important guidelines:
- This is for educational and research purposes only
- Do not provide medical advice or diagnosis
- Keep the analysis objective and educational
- Focus on visible features and general wound characteristics
- Do not recommend treatments or medical interventions
Please provide a comprehensive educational analysis."""
        response = gemini_model.generate_content([prompt, image])
        return response.text
    except Exception as e:
        return f"Error analyzing image with Gemini: {str(e)}"

def analyze_wound_depth_with_gemini(image, depth_map, depth_stats):
    """
    Analyze wound depth and severity using Gemini AI with depth analysis context.

    Args:
        image: Original wound image (PIL Image or numpy array)
        depth_map: Depth map (numpy array)
        depth_stats: Dictionary containing depth analysis statistics
    Returns:
        str: Gemini AI medical assessment based on depth analysis
    """
    if image is None or depth_map is None:
        return "No image or depth map provided for analysis."
    if gemini_model is None:
        return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."
    try:
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Ensure the image is in RGB format
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Convert the depth map to a PIL Image for Gemini
        if isinstance(depth_map, np.ndarray):
            # Normalize the depth map for visualization
            norm_depth = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0
            depth_image = Image.fromarray(norm_depth.astype(np.uint8))
        else:
            depth_image = depth_map
        # Create a detailed prompt with the depth statistics
        prompt = f"""You are a medical AI assistant specializing in wound assessment. Analyze this wound using both the original image and depth map data.
DEPTH ANALYSIS DATA PROVIDED:
- Total Wound Area: {depth_stats['total_area_cm2']:.2f} cm²
- Mean Depth: {depth_stats['mean_depth_mm']:.1f} mm
- Maximum Depth: {depth_stats['max_depth_mm']:.1f} mm
- Depth Standard Deviation: {depth_stats['depth_std_mm']:.1f} mm
- Wound Volume: {depth_stats['wound_volume_cm3']:.2f} cm³
- Deep Tissue Involvement: {depth_stats['deep_ratio']*100:.1f}%
- Analysis Quality: {depth_stats['analysis_quality']}
- Depth Consistency: {depth_stats['depth_consistency']}
TISSUE DEPTH DISTRIBUTION:
- Superficial Areas (0-2mm): {depth_stats['superficial_area_cm2']:.2f} cm²
- Partial Thickness (2-4mm): {depth_stats['partial_thickness_area_cm2']:.2f} cm²
- Full Thickness (4-6mm): {depth_stats['full_thickness_area_cm2']:.2f} cm²
- Deep Areas (>6mm): {depth_stats['deep_area_cm2']:.2f} cm²
STATISTICAL DEPTH ANALYSIS:
- 25th Percentile Depth: {depth_stats['depth_percentiles']['25']:.1f} mm
- Median Depth: {depth_stats['depth_percentiles']['50']:.1f} mm
- 75th Percentile Depth: {depth_stats['depth_percentiles']['75']:.1f} mm
Please provide a comprehensive medical assessment focusing on:
1. **WOUND CHARACTERISTICS ANALYSIS**
   - Visible wound features from the original image
   - Correlation between visual appearance and depth measurements
   - Tissue quality assessment based on color, texture, and depth data
2. **DEPTH-BASED SEVERITY ASSESSMENT**
   - Clinical significance of the measured depths
   - Tissue layer involvement based on depth measurements
   - Risk assessment based on deep tissue involvement percentage
3. **HEALING PROGNOSIS**
   - Expected healing timeline based on depth and area measurements
   - Factors that may affect healing based on depth distribution
   - Complexity assessment based on wound volume and depth variation
4. **CLINICAL CONSIDERATIONS**
   - Significance of depth consistency/inconsistency
   - Areas of particular concern based on depth analysis
   - Educational insights about this type of wound presentation
5. **MEASUREMENT INTERPRETATION**
   - Clinical relevance of the statistical depth measurements
   - What the depth distribution tells us about wound progression
   - Comparison to typical wound depth classifications
IMPORTANT GUIDELINES:
- This is for educational and research purposes only
- Do not provide specific medical advice or treatment recommendations
- Focus on objective analysis of the provided measurements
- Correlate visual findings with quantitative depth data
- Maintain educational and clinical terminology
- Emphasize the relationship between depth measurements and clinical significance
Provide a detailed, structured medical assessment that integrates both visual and quantitative depth analysis."""
        # Send both images to Gemini for analysis
        response = gemini_model.generate_content([prompt, image, depth_image])
        return response.text
    except Exception as e:
        return f"Error analyzing wound with Gemini AI: {str(e)}"

def classify_wound(image):
    """
    Classify the wound type from an uploaded image.

    Args:
        image: PIL Image or numpy array
    Returns:
        dict: Classification results with confidence scores
    """
    if image is None:
        return "Please upload an image"
    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Ensure the image is in RGB format
    if image.mode != 'RGB':
        image = image.convert('RGB')
    try:
        # Process the image
        inputs = classification_processor(images=image, return_tensors="pt")
        # Get model predictions
        with torch.no_grad():
            outputs = classification_model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits[0], dim=-1)
        # Get the predicted class labels and confidence scores
        confidence_scores = predictions.numpy()
        # Create the results dictionary
        results = {}
        for i, score in enumerate(confidence_scores):
            # Get the class name from the model config
            class_name = classification_model.config.id2label[i] if hasattr(classification_model.config, 'id2label') else f"Class {i}"
            results[class_name] = float(score)
        return results
    except Exception as e:
        return f"Error processing image: {str(e)}"

def classify_and_analyze_wound(image):
    """
    Combined function to classify the wound and get a Gemini analysis.

    Args:
        image: PIL Image or numpy array
    Returns:
        tuple: (classification_results, gemini_analysis)
    """
    if image is None:
        return "Please upload an image", "Please upload an image for analysis"
    # Get classification results
    classification_results = classify_wound(image)
    # Get the top predicted label for the Gemini analysis
    if isinstance(classification_results, dict) and classification_results:
        # Get the label with the highest confidence
        top_label = max(classification_results.items(), key=lambda x: x[1])[0]
        # Get the Gemini analysis
        gemini_analysis = analyze_wound_with_gemini(image, top_label)
    else:
        top_label = "Unknown"
        gemini_analysis = "Unable to analyze due to classification error"
    return classification_results, gemini_analysis

def format_gemini_analysis(analysis):
    """Format the Gemini analysis as properly structured HTML."""
    if not analysis or "Error" in analysis:
        return f"""
        <div style="
            background-color: #fee2e2;
            border-radius: 12px;
            padding: 16px;
            box-shadow: 0 4px 12px rgba(0,0,0,0.1);
            font-family: Arial, sans-serif;
            min-height: 300px;
            border-left: 4px solid #ef4444;
        ">
            <h4 style="color: #dc2626; margin-top: 0;">Analysis Error</h4>
            <p style="color: #991b1b;">{analysis}</p>
        </div>
        """
    # Parse the markdown-style response and convert it to HTML
    formatted_analysis = parse_markdown_to_html(analysis)
    return f"""
    <div style="
        border-radius: 12px;
        padding: 25px;
        box-shadow: 0 4px 12px rgba(0,0,0,0.1);
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        min-height: 300px;
        border-left: 4px solid #d97706;
        max-height: 600px;
        overflow-y: auto;
    ">
        <h3 style="color: #d97706; margin-top: 0; margin-bottom: 20px; display: flex; align-items: center; gap: 8px;">
            Initial Wound Analysis
        </h3>
        <div style="color: white; line-height: 1.7;">
            {formatted_analysis}
        </div>
    </div>
    """

def format_gemini_depth_analysis(analysis):
    """Format the Gemini depth analysis as structured HTML for the medical assessment."""
    if not analysis or "Error" in analysis:
        return f"""
        <div style="color: #ffffff; line-height: 1.6;">
            <div style="font-size: 16px; font-weight: bold; margin-bottom: 10px; color: #f44336;">
                ❌ AI Analysis Error
            </div>
            <div style="color: #cccccc;">
                {analysis}
            </div>
        </div>
        """
    # Parse the markdown-style response and convert it to HTML
    formatted_analysis = parse_markdown_to_html(analysis)
    return f"""
    <div style="color: #ffffff; line-height: 1.6;">
        <div style="font-size: 16px; font-weight: bold; margin-bottom: 15px; color: #4CAF50;">
            🤖 AI-Powered Medical Assessment
        </div>
        <div style="color: #cccccc; max-height: 400px; overflow-y: auto; padding-right: 10px;">
            {formatted_analysis}
        </div>
    </div>
    """

def parse_markdown_to_html(text):
    """Convert markdown-style text to HTML."""
    import re
    # Replace markdown headers
    text = re.sub(r'^### \*\*(.*?)\*\*$', r'<h4 style="color: #d97706; margin: 20px 0 10px 0; font-weight: bold;">\1</h4>', text, flags=re.MULTILINE)
    text = re.sub(r'^#### \*\*(.*?)\*\*$', r'<h5 style="color: #f59e0b; margin: 15px 0 8px 0; font-weight: bold;">\1</h5>', text, flags=re.MULTILINE)
    text = re.sub(r'^### (.*?)$', r'<h4 style="color: #d97706; margin: 20px 0 10px 0; font-weight: bold;">\1</h4>', text, flags=re.MULTILINE)
    text = re.sub(r'^#### (.*?)$', r'<h5 style="color: #f59e0b; margin: 15px 0 8px 0; font-weight: bold;">\1</h5>', text, flags=re.MULTILINE)
    # Replace bold text
    text = re.sub(r'\*\*(.*?)\*\*', r'<strong style="color: #fbbf24;">\1</strong>', text)
    # Replace italic text
    text = re.sub(r'\*(.*?)\*', r'<em style="color: #fde68a;">\1</em>', text)
    # Replace bullet points
    text = re.sub(r'^\* (.*?)$', r'<li style="margin: 5px 0; color: white;">\1</li>', text, flags=re.MULTILINE)
    text = re.sub(r'^ \* (.*?)$', r'<li style="margin: 3px 0; margin-left: 20px; color: white;">\1</li>', text, flags=re.MULTILINE)
    # Wrap consecutive list items in ul tags
    text = re.sub(r'(<li.*?</li>(?:\s*<li.*?</li>)*)', r'<ul style="margin: 10px 0; padding-left: 20px;">\1</ul>', text, flags=re.DOTALL)
    # Replace numbered lists
    text = re.sub(r'^(\d+)\.\s+(.*?)$', r'<div style="margin: 8px 0; color: white;"><strong style="color: #d97706;">\1.</strong> \2</div>', text, flags=re.MULTILINE)
    # Convert paragraphs (double newlines)
    paragraphs = text.split('\n\n')
    formatted_paragraphs = []
    for para in paragraphs:
        para = para.strip()
        if para:
            # Skip wrapping if the block is already wrapped in HTML tags
            if not (para.startswith('<') or para.endswith('>')):
                para = f'<p style="margin: 12px 0; color: white; text-align: justify;">{para}</p>'
            formatted_paragraphs.append(para)
    return '\n'.join(formatted_paragraphs)
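
# Example of the conversion above (illustrative): the input
#   "### **Wound Characteristics**\n\n* Irregular edges"
# is rendered roughly as
#   <h4 ...>Wound Characteristics</h4>
#   <ul ...><li ...>Irregular edges</li></ul>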

def combined_analysis(image):
    """Combined function for the UI that returns both outputs."""
    classification, gemini_analysis = classify_and_analyze_wound(image)
    formatted_analysis = format_gemini_analysis(gemini_analysis)
    return classification, formatted_analysis

# Define the checkpoint path and Google Drive file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=1tfNiBCB2-yF1xtHU5KzCJOlwMs20v1_E"

# Download if not already present
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: Check GPU Availability ---
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")

# --- Load Actual Wound Segmentation Model ---
class WoundSegmentationModel:
    def __init__(self):
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained wound segmentation model."""
        try:
            # Try to load the most recent model
            weight_file_name = '2025-08-07_16-25-27.hdf5'
            model_path = f'./training_history/{weight_file_name}'
            self.model = load_model(model_path,
                                    custom_objects={
                                        'recall': recall,
                                        'precision': precision,
                                        'dice_coef': dice_coef,
                                        'relu6': relu6,
                                        'DepthwiseConv2D': DepthwiseConv2D,
                                        'BilinearUpsampling': BilinearUpsampling
                                    })
            print(f"Segmentation model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading segmentation model: {e}")
            # Fall back to the older model
            try:
                weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
                model_path = f'./training_history/{weight_file_name}'
                self.model = load_model(model_path,
                                        custom_objects={
                                            'recall': recall,
                                            'precision': precision,
                                            'dice_coef': dice_coef,
                                            'relu6': relu6,
                                            'DepthwiseConv2D': DepthwiseConv2D,
                                            'BilinearUpsampling': BilinearUpsampling
                                        })
                print(f"Segmentation model loaded successfully from {model_path}")
            except Exception as e2:
                print(f"Error loading fallback segmentation model: {e2}")
                self.model = None

    def preprocess_image(self, image):
        """Preprocess the uploaded image for model input."""
        if image is None:
            return None
        # Convert BGR to RGB if the image has three channels
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Resize to the model input size
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
        # Normalize the image
        image = image.astype(np.float32) / 255.0
        # Add a batch dimension
        image = np.expand_dims(image, axis=0)
        return image

    def postprocess_prediction(self, prediction):
        """Postprocess the model prediction."""
        # Remove the batch dimension
        prediction = prediction[0]
        # Apply a threshold to get a binary mask
        threshold = 0.5
        binary_mask = (prediction > threshold).astype(np.uint8) * 255
        return binary_mask

    def segment_wound(self, input_image):
        """Main entry point: segment the wound from an uploaded image."""
        if self.model is None:
            return None, "Error: Segmentation model not loaded. Please check the model files."
        if input_image is None:
            return None, "Please upload an image."
        try:
            # Preprocess the image
            processed_image = self.preprocess_image(input_image)
            if processed_image is None:
                return None, "Error processing image."
            # Make the prediction
            prediction = self.model.predict(processed_image, verbose=0)
            # Postprocess the prediction
            segmented_mask = self.postprocess_prediction(prediction)
            return segmented_mask, "Segmentation completed successfully!"
        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"


# Initialize the segmentation model
segmentation_model = WoundSegmentationModel()

# --- PyTorch: Set Device and Load Depth Model ---
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(
    f'checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()

# --- Custom CSS for unified dark theme ---
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
/* Card styling for consistent heights */
.wound-card {
    min-height: 200px !important;
    display: flex !important;
    flex-direction: column !important;
    justify-content: space-between !important;
}
.wound-card-content {
    flex-grow: 1 !important;
    display: flex !important;
    flex-direction: column !important;
    justify-content: center !important;
}
/* Loading animation */
.loading-spinner {
    display: inline-block;
    width: 20px;
    height: 20px;
    border: 3px solid #f3f3f3;
    border-top: 3px solid #3498db;
    border-radius: 50%;
    animation: spin 1s linear infinite;
}
@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
"""

# --- Enhanced Wound Severity Estimation Functions ---
def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Enhanced depth analysis with proper calibration and medical standards.

    Based on wound depth classification standards:
    - Superficial: 0-2mm (epidermis only)
    - Partial thickness: 2-4mm (epidermis + partial dermis)
    - Full thickness: 4-6mm (epidermis + full dermis)
    - Deep: >6mm (involving subcutaneous tissue)
    """
    # Convert pixel spacing to mm
    pixel_spacing_mm = float(pixel_spacing_mm)
    # Calculate the pixel area in cm²
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2
    # Extract the wound region (binary mask)
    wound_mask = (mask > 127).astype(np.uint8)
    # Apply morphological operations to clean the mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel)
    # Get depth values only for the wound region
    wound_depths = depth_map[wound_mask > 0]
    if len(wound_depths) == 0:
        return {
            'total_area_cm2': 0,
            'superficial_area_cm2': 0,
            'partial_thickness_area_cm2': 0,
            'full_thickness_area_cm2': 0,
            'deep_area_cm2': 0,
            'mean_depth_mm': 0,
            'max_depth_mm': 0,
            'depth_std_mm': 0,
            'deep_ratio': 0,
            'wound_volume_cm3': 0,
            'depth_percentiles': {'25': 0, '50': 0, '75': 0}
        }
    # Normalize depth relative to the nearest point in the wound area
    normalized_depth_map, nearest_point_coords, max_relative_depth = normalize_depth_relative_to_nearest_point(depth_map, wound_mask)
    # Calibrate the normalized depth map for more accurate measurements
    calibrated_depth_map = calibrate_depth_map(normalized_depth_map, reference_depth_mm=depth_calibration_mm)
    # Get calibrated depth values for the wound region
    wound_depths_mm = calibrated_depth_map[wound_mask > 0]
    # Medical depth classification
    superficial_mask = wound_depths_mm < 2.0
    partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0)
    full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0)
    deep_mask = wound_depths_mm >= 6.0
    # Calculate areas
    total_pixels = np.sum(wound_mask > 0)
    total_area_cm2 = total_pixels * pixel_area_cm2
    superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2
    partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2
    full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2
    deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2
    # Calculate depth statistics
    mean_depth_mm = np.mean(wound_depths_mm)
    max_depth_mm = np.max(wound_depths_mm)
    depth_std_mm = np.std(wound_depths_mm)
    # Calculate depth percentiles
    depth_percentiles = {
        '25': np.percentile(wound_depths_mm, 25),
        '50': np.percentile(wound_depths_mm, 50),
        '75': np.percentile(wound_depths_mm, 75)
    }
    # Calculate depth distribution statistics
    depth_distribution = {
        'shallow_ratio': np.sum(wound_depths_mm < 2.0) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0,
        'moderate_ratio': np.sum((wound_depths_mm >= 2.0) & (wound_depths_mm < 5.0)) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0,
        'deep_ratio': np.sum(wound_depths_mm >= 5.0) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0
    }
    # Calculate the approximate wound volume: volume = area * average depth
    wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0)
    # Deep tissue ratio
    deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0
    # Calculate analysis quality metrics
    wound_pixel_count = len(wound_depths_mm)
    analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low"
    # Calculate depth consistency (lower std dev = more consistent)
    depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low"
    return {
        'total_area_cm2': total_area_cm2,
        'superficial_area_cm2': superficial_area_cm2,
        'partial_thickness_area_cm2': partial_thickness_area_cm2,
        'full_thickness_area_cm2': full_thickness_area_cm2,
        'deep_area_cm2': deep_area_cm2,
        'mean_depth_mm': mean_depth_mm,
        'max_depth_mm': max_depth_mm,
        'depth_std_mm': depth_std_mm,
        'deep_ratio': deep_ratio,
        'wound_volume_cm3': wound_volume_cm3,
        'depth_percentiles': depth_percentiles,
        'depth_distribution': depth_distribution,
        'analysis_quality': analysis_quality,
        'depth_consistency': depth_consistency,
        'wound_pixel_count': wound_pixel_count,
        'nearest_point_coords': nearest_point_coords,
        'max_relative_depth': max_relative_depth,
        'normalized_depth_map': normalized_depth_map
    }
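
# Worked example of the arithmetic above (illustrative numbers): with
# pixel_spacing_mm = 0.5, each pixel covers (0.5 / 10)**2 = 0.0025 cm², so a
# 4,000-pixel wound mask yields 4,000 * 0.0025 = 10 cm² of area; with a mean
# calibrated depth of 3 mm, the approximate volume is 10 * (3 / 10) = 3 cm³.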

def classify_wound_severity_by_enhanced_metrics(depth_stats):
    """
    Enhanced wound severity classification based on medical standards.
    Uses multiple criteria: depth, area, volume, and tissue involvement.
    """
    if depth_stats['total_area_cm2'] == 0:
        return "Unknown"
    # Extract the key metrics
    total_area = depth_stats['total_area_cm2']
    deep_area = depth_stats['deep_area_cm2']
    full_thickness_area = depth_stats['full_thickness_area_cm2']
    mean_depth = depth_stats['mean_depth_mm']
    max_depth = depth_stats['max_depth_mm']
    wound_volume = depth_stats['wound_volume_cm3']
    deep_ratio = depth_stats['deep_ratio']
    # Medical severity classification criteria
    severity_score = 0
    # Criterion 1: Maximum depth
    if max_depth >= 10.0:
        severity_score += 3  # Very severe
    elif max_depth >= 6.0:
        severity_score += 2  # Severe
    elif max_depth >= 4.0:
        severity_score += 1  # Moderate
    # Criterion 2: Mean depth
    if mean_depth >= 5.0:
        severity_score += 2
    elif mean_depth >= 3.0:
        severity_score += 1
    # Criterion 3: Deep tissue involvement ratio
    if deep_ratio >= 0.5:
        severity_score += 3  # More than 50% deep tissue
    elif deep_ratio >= 0.25:
        severity_score += 2  # 25-50% deep tissue
    elif deep_ratio >= 0.1:
        severity_score += 1  # 10-25% deep tissue
    # Criterion 4: Total wound area
    if total_area >= 10.0:
        severity_score += 2  # Large wound (>10 cm²)
    elif total_area >= 5.0:
        severity_score += 1  # Medium wound (5-10 cm²)
    # Criterion 5: Wound volume
    if wound_volume >= 5.0:
        severity_score += 2  # High volume
    elif wound_volume >= 2.0:
        severity_score += 1  # Medium volume
    # Determine severity based on the total score
    if severity_score >= 8:
        return "Very Severe"
    elif severity_score >= 6:
        return "Severe"
    elif severity_score >= 4:
        return "Moderate"
    elif severity_score >= 2:
        return "Mild"
    else:
        return "Superficial"

def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """Enhanced wound severity analysis based on depth measurements."""
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload an image, depth map, and wound mask."
    # Convert the wound mask to grayscale if needed
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)
    # Ensure the depth map and mask have the same dimensions
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Resize the mask to match the depth map
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)
    # Compute enhanced statistics with relative depth normalization
    stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm)
    # Get the severity based on the enhanced metrics
    severity_level = classify_wound_severity_by_enhanced_metrics(stats)
    severity_description = get_enhanced_severity_description(severity_level)
    # Get the Gemini AI analysis based on the depth data
    gemini_analysis = analyze_wound_depth_with_gemini(image, depth_map, stats)
    # Format the Gemini analysis for display
    formatted_gemini_analysis = format_gemini_depth_analysis(gemini_analysis)
    # Create the depth analysis visualization
    depth_visualization = create_depth_analysis_visualization(
        stats['normalized_depth_map'], wound_mask,
        stats['nearest_point_coords'], stats['max_relative_depth']
    )
    # Enhanced severity color coding
    severity_color = {
        "Superficial": "#4CAF50",   # Green
        "Mild": "#8BC34A",          # Light Green
        "Moderate": "#FF9800",      # Orange
        "Severe": "#F44336",        # Red
        "Very Severe": "#9C27B0"    # Purple
    }.get(severity_level, "#9E9E9E")  # Gray for unknown
    # Create the comprehensive medical report
    report = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
            🩹 Enhanced Wound Severity Analysis
        </div>
        <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px;'>
            <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 15px; text-align: center;'>
                📊 Depth & Quality Analysis
            </div>
            <div style='color: #cccccc; line-height: 1.6; display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 20px;'>
                <div>
                    <div style='font-size: 16px; font-weight: bold; color: #ff9800; margin-bottom: 8px;'>📏 Basic Measurements</div>
                    <div>📏 <b>Mean Relative Depth:</b> {stats['mean_depth_mm']:.1f} mm</div>
                    <div>📏 <b>Max Relative Depth:</b> {stats['max_depth_mm']:.1f} mm</div>
                    <div>📊 <b>Depth Std Dev:</b> {stats['depth_std_mm']:.1f} mm</div>
                    <div>📦 <b>Wound Volume:</b> {stats['wound_volume_cm3']:.2f} cm³</div>
                    <div>🔥 <b>Deep Tissue Ratio:</b> {stats['deep_ratio']*100:.1f}%</div>
                </div>
                <div>
                    <div style='font-size: 16px; font-weight: bold; color: #4CAF50; margin-bottom: 8px;'>📈 Statistical Analysis</div>
                    <div>📊 <b>25th Percentile:</b> {stats['depth_percentiles']['25']:.1f} mm</div>
                    <div>📊 <b>Median (50th):</b> {stats['depth_percentiles']['50']:.1f} mm</div>
                    <div>📊 <b>75th Percentile:</b> {stats['depth_percentiles']['75']:.1f} mm</div>
                    <div>📊 <b>Shallow Areas:</b> {stats['depth_distribution']['shallow_ratio']*100:.1f}%</div>
                    <div>📊 <b>Moderate Areas:</b> {stats['depth_distribution']['moderate_ratio']*100:.1f}%</div>
                </div>
                <div>
                    <div style='font-size: 16px; font-weight: bold; color: #2196F3; margin-bottom: 8px;'>🔍 Quality Metrics</div>
                    <div>🔍 <b>Analysis Quality:</b> {stats['analysis_quality']}</div>
                    <div>🔍 <b>Depth Consistency:</b> {stats['depth_consistency']}</div>
                    <div>🔍 <b>Data Points:</b> {stats['wound_pixel_count']:,}</div>
                    <div>🔍 <b>Deep Areas:</b> {stats['depth_distribution']['deep_ratio']*100:.1f}%</div>
                    <div>🎯 <b>Reference Point:</b> Nearest to camera</div>
                </div>
            </div>
        </div>
        <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid {severity_color};'>
            <div style='font-size: 18px; font-weight: bold; color: {severity_color}; margin-bottom: 10px;'>
                📋 Medical Assessment Based on Depth Analysis
            </div>
            {formatted_gemini_analysis}
        </div>
    </div>
    """
    return report

def normalize_depth_relative_to_nearest_point(depth_map, wound_mask):
    """
    Normalize the depth map relative to the nearest point in the wound area.
    This assumes a top-down camera perspective where the point closest to the camera has depth 0.

    Args:
        depth_map: Raw depth map
        wound_mask: Binary mask of the wound region
    Returns:
        normalized_depth: Depth values relative to the nearest point (0 = nearest, positive = deeper)
        nearest_point_coords: Coordinates of the nearest point
        max_relative_depth: Maximum relative depth in the wound
    """
    if depth_map is None or wound_mask is None:
        return depth_map, None, 0
    # Convert the mask to binary. Callers may pass either a 0/1 mask or a
    # 0-255 mask, so treat any nonzero pixel as wound (a `> 127` test would
    # wrongly discard 0/1 masks and make the normalization a no-op).
    binary_mask = (wound_mask > 0).astype(np.uint8)
    # Find the wound region coordinates
    wound_coords = np.where(binary_mask > 0)
    if len(wound_coords[0]) == 0:
        return depth_map, None, 0
    # Get depth values only for the wound region
    wound_depths = depth_map[wound_coords]
    # Find the nearest point (minimum depth value in the wound region)
    nearest_depth = np.min(wound_depths)
    nearest_indices = np.where(wound_depths == nearest_depth)
    # Get the coordinates of the nearest point
    nearest_point_coords = (wound_coords[0][nearest_indices[0][0]],
                            wound_coords[1][nearest_indices[0][0]])
    # Create the normalized depth map (relative to the nearest point)
    normalized_depth = depth_map.copy()
    normalized_depth = normalized_depth - nearest_depth
    # Ensure all values are non-negative (nearest point = 0, others positive)
    normalized_depth = np.maximum(normalized_depth, 0)
    # Calculate the maximum relative depth in the wound region
    wound_normalized_depths = normalized_depth[wound_coords]
    max_relative_depth = np.max(wound_normalized_depths)
    return normalized_depth, nearest_point_coords, max_relative_depth
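
# Example (illustrative): if the raw depths inside the mask are
# [12.1, 12.4, 13.0], the nearest point is 12.1, so the relative values
# become [0.0, 0.3, 0.9] and max_relative_depth = 0.9.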

def calibrate_depth_map(depth_map, reference_depth_mm=10.0):
    """
    Calibrate the depth map to real-world measurements using a reference depth.
    This converts normalized depth values to approximate millimeters.
    """
    if depth_map is None:
        return depth_map
    # Find the depth range of the map
    max_depth_value = np.max(depth_map)
    min_depth_value = np.min(depth_map)
    if max_depth_value == min_depth_value:
        return depth_map
    # Apply the calibration to convert to millimeters, assuming the maximum
    # depth in the map corresponds to reference_depth_mm
    calibrated_depth = (depth_map - min_depth_value) / (max_depth_value - min_depth_value) * reference_depth_mm
    return calibrated_depth
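
# Example (illustrative): normalized depth values [0.0, 0.5, 1.0] with
# reference_depth_mm = 15.0 map linearly to [0.0, 7.5, 15.0] mm -- the full
# range of the map is assumed to span the expected maximum wound depth.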

def create_depth_analysis_visualization(depth_map, wound_mask, nearest_point_coords, max_relative_depth):
    """
    Create a visualization of the depth analysis with the nearest (reference)
    point and the deepest point highlighted.
    """
    if depth_map is None or wound_mask is None:
        return None
    # Work on a copy of the depth map for visualization
    vis_depth = depth_map.copy()
    # Apply a colormap for better visualization
    normalized_depth = (vis_depth - np.min(vis_depth)) / (np.max(vis_depth) - np.min(vis_depth))
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(normalized_depth)[:, :, :3] * 255).astype(np.uint8)
    # Convert to RGB if grayscale
    if len(colored_depth.shape) == 3 and colored_depth.shape[2] == 1:
        colored_depth = cv2.cvtColor(colored_depth, cv2.COLOR_GRAY2RGB)
    # Highlight the nearest point (reference point) with a red circle
    if nearest_point_coords is not None:
        y, x = nearest_point_coords
        cv2.circle(colored_depth, (x, y), 10, (255, 0, 0), 2)  # Red circle for the nearest point
        cv2.putText(colored_depth, "REF", (x + 15, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
    # Find and highlight the deepest point
    binary_mask = (wound_mask > 127).astype(np.uint8)
    wound_coords = np.where(binary_mask > 0)
    if len(wound_coords[0]) > 0:
        # Get the depth values for the wound region
        wound_depths = vis_depth[wound_coords]
        max_depth_idx = np.argmax(wound_depths)
        deepest_point_coords = (wound_coords[0][max_depth_idx], wound_coords[1][max_depth_idx])
        # Highlight the deepest point with a blue circle
        y, x = deepest_point_coords
        cv2.circle(colored_depth, (x, y), 12, (0, 0, 255), 3)  # Blue circle for the deepest point
        cv2.putText(colored_depth, "DEEP", (x + 15, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    # Overlay the wound mask outline
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(colored_depth, contours, -1, (0, 255, 0), 2)  # Green outline for the wound boundary
    return colored_depth

def get_enhanced_severity_description(severity):
    """Get a comprehensive medical description for a severity level."""
    descriptions = {
        "Superficial": "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care.",
        "Mild": "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care.",
        "Moderate": "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques.",
        "Severe": "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention.",
        "Very Severe": "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care.",
        "Unknown": "Unable to determine severity due to insufficient data or poor image quality."
    }
    return descriptions.get(severity, "Severity assessment unavailable.")

def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Create a sample circular wound mask for testing."""
    if center is None:
        center = (image_shape[1] // 2, image_shape[0] // 2)
    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    y, x = np.ogrid[:image_shape[0], :image_shape[1]]
    # Create a circular mask
    dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2)
    mask[dist_from_center <= radius] = 255
    return mask


def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Create a more realistic wound mask with irregular shapes."""
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    if method == 'elliptical':
        # Create an elliptical wound mask
        center = (w // 2, h // 2)
        radius_x = min(w, h) // 3
        radius_y = min(w, h) // 4
        y, x = np.ogrid[:h, :w]
        # Base ellipse
        ellipse = ((x - center[0])**2 / (radius_x**2) +
                   (y - center[1])**2 / (radius_y**2)) <= 1
        # Add some noise and irregularity to make it more realistic
        noise = np.random.random((h, w)) > 0.8
        mask = (ellipse | noise).astype(np.uint8) * 255
    elif method == 'irregular':
        # Create an irregular wound mask
        center = (w // 2, h // 2)
        radius = min(w, h) // 4
        y, x = np.ogrid[:h, :w]
        base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius
        # Add irregular extensions
        extensions = np.zeros_like(base_circle)
        for i in range(3):
            angle = i * 2 * np.pi / 3
            ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
            ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
            ext_radius = radius // 3
            ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
            extensions = extensions | ext_circle
        mask = (base_circle | extensions).astype(np.uint8) * 255
    # Apply morphological operations to smooth the mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask

# --- Depth Estimation Functions ---
def predict_depth(image):
    return depth_model.infer_image(image)


def calculate_max_points(image):
    """Calculate the maximum number of 3D points based on image dimensions (3x pixel count)."""
    if image is None:
        return 10000  # Default value
    h, w = image.shape[:2]
    max_points = h * w * 3
    # Enforce minimum and reasonable maximum values
    return max(1000, min(max_points, 300000))
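
# Example: a 640x480 upload gives 640 * 480 * 3 = 921,600 candidate points,
# which the clamp above caps at 300,000.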

def update_slider_on_image_upload(image):
    """Update the points slider when an image is uploaded."""
    max_points = calculate_max_points(image)
    default_value = min(10000, max_points // 10)  # 10% of max points as the default
    return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {max_points:,})")

def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Create a point cloud from a depth map using camera intrinsics with high detail."""
    h, w = depth_map.shape
    # Use a smaller step for higher detail (reduced downsampling)
    step = max(1, int(np.sqrt(h * w / max_points) * 0.5))
    # Create a mesh grid of pixel coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Convert to camera coordinates (normalized by focal length)
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y
    # Get the depth values
    depth_values = depth_map[::step, ::step]
    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    # Flatten the arrays into an (N, 3) point array
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
    # Get the corresponding image colors
    image_colors = image[::step, ::step, :]
    colors = image_colors.reshape(-1, 3) / 255.0
    # Create the Open3D point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd
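
# The back-projection above is the standard pinhole model with the principal
# point assumed at the image center: for pixel (u, v) with depth Z,
#   X = (u - w/2) * Z / fx,   Y = (v - h/2) * Z / fy,   point = (X, Y, Z).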

def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Convert a point cloud to a mesh using Poisson reconstruction with very high detail."""
    # Estimate and orient normals with high precision
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)
    # Create the surface mesh with maximum detail (depth=12 for very high resolution)
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
    # Return the mesh without filtering low-density vertices
    return mesh
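
# Note: the `depth` argument of create_from_point_cloud_poisson sets the
# octree depth, so depth=12 is very fine-grained and memory-hungry; a lower
# value (e.g. 9) is a reasonable fallback on constrained hardware.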

def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an enhanced 3D visualization using proper camera projection."""
    h, w = depth_map.shape
    # Downsample to avoid too many points, for performance
    step = max(1, int(np.sqrt(h * w / max_points)))
    # Create a mesh grid of pixel coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Convert to camera coordinates (normalized by focal length)
    focal_length = 470.4  # Default focal length
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length
    # Get the depth values
    depth_values = depth_map[::step, ::step]
    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    # Flatten the arrays
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()
    # Get the corresponding image colors
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)
    # Convert colors to Plotly-compatible rgb strings
    colors_rgb = [f'rgb({r},{g},{b})' for r, g, b in colors_flat]
    # Create a 3D scatter plot with proper camera projection
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_rgb,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])
    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )
    return fig

def on_depth_submit(image, num_points, focal_x, focal_y):
    original_image = image.copy()
    h, w = image.shape[:2]
    # Predict depth using the model (reverse the channels: RGB to BGR)
    depth = predict_depth(image[:, :, ::-1])
    # Save the raw 16-bit depth
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)
    # Normalize and convert to grayscale for display
    norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)
    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)
    # Create the point cloud
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    # Reconstruct a mesh from the point cloud
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
    # Save the mesh (with faces) as .ply
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)
    # Create the enhanced 3D scatter plot visualization
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)
    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]

# --- Actual Wound Segmentation Functions ---
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate a wound mask from an image using the actual deep learning model.

    Args:
        image: Input image (numpy array)
        method: Segmentation method (currently only 'deep_learning' is supported)
    Returns:
        mask: Binary wound mask
    """
    if image is None:
        return None
    # Use the deep learning model for segmentation; any unrecognized method
    # also falls back to the same model
    mask, _ = segmentation_model.segment_wound(image)
    return mask

def post_process_wound_mask(mask, min_area=100):
    """Post-process the wound mask to remove noise and small objects."""
    if mask is None:
        return None
    # Convert to uint8 if needed
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    # Apply morphological operations to clean up the mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    # Remove small objects using OpenCV contours
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area >= min_area:
            cv2.fillPoly(mask_clean, [contour], 255)
    # Fill remaining holes
    mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)
    return mask_clean
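
# Note on the ordering above: closing first bridges small gaps inside the
# predicted mask, opening then removes isolated specks, and the contour pass
# finally drops any remaining component smaller than min_area pixels.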

def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Analyze wound severity with an automatically generated mask from the segmentation model."""
    if image is None or depth_map is None:
        return "❌ Please provide both an image and a depth map."
    # Generate the automatic wound mask using the actual model
    auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if auto_mask is None:
        return "❌ Failed to generate an automatic wound mask. Please check that the segmentation model is loaded."
    # Post-process the mask
    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    if processed_mask is None or np.sum(processed_mask > 0) == 0:
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use a manual mask."
    # Analyze severity using the automatic mask
    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)
| # --- Main Gradio Interface --- | |
| with gr.Blocks(css=css, title="Wound Analysis System") as demo: | |
| gr.HTML("<h1>Wound Analysis System</h1>") | |
| #gr.Markdown("### Complete workflow: Classification β Depth Estimation β Wound Severity Analysis") | |
| # Shared states | |
| shared_image = gr.State() | |
| shared_depth_map = gr.State() | |
| with gr.Tabs(): | |
| # Tab 1: Wound Classification | |
| with gr.Tab("1. π Wound Classification & Initial Analysis"): | |
| gr.Markdown("### Step 1: Classify wound type and get initial AI analysis") | |
| #gr.Markdown("Upload an image to identify the wound type and receive detailed analysis from our Vision AI.") | |
| with gr.Row(): | |
| # Left Column - Image Upload | |
| with gr.Column(scale=1): | |
| gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 0; font-weight: bold; font-size: 1.8rem;">Upload Wound Image</h2>') | |
| classification_image_input = gr.Image( | |
| label="", | |
| type="pil", | |
| height=400 | |
| ) | |
| # Place Clear and Analyse buttons side by side | |
| with gr.Row(): | |
| classify_clear_btn = gr.Button( | |
| "Clear", | |
| variant="secondary", | |
| size="lg", | |
| scale=1 | |
| ) | |
| analyse_btn = gr.Button( | |
| "Analyse", | |
| variant="primary", | |
| size="lg", | |
| scale=1 | |
| ) | |
| # Right Column - Classification Results | |
| with gr.Column(scale=1): | |
| gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 0; font-weight: bold; font-size: 1.8rem;">Classification Results</h2>') | |
| classification_output = gr.Label( | |
| label="", | |
| num_top_classes=5, | |
| show_label=False | |
| ) | |
| # Second Row - Full Width AI Analysis | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 2rem; margin-bottom: 1rem; font-weight: bold; font-size: 1.8rem;">Wound Visual Analysis</h2>') | |
| gemini_output = gr.HTML( | |
| value=""" | |
| <div style=" | |
| border-radius: 12px; | |
| padding: 20px; | |
| box-shadow: 0 4px 12px rgba(0,0,0,0.1); | |
| font-family: Arial, sans-serif; | |
| min-height: 200px; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| color: white; | |
| width: 100%; | |
| border-left: 4px solid #d97706; | |
| font-weight: bold; | |
| "> | |
| Upload an image to get AI-powered wound analysis | |
| </div> | |
| """ | |
| ) | |
| # Event handlers for classification tab | |
| classify_clear_btn.click( | |
| fn=lambda: (None, None, """ | |
| <div style=" | |
| border-radius: 12px; | |
| padding: 20px; | |
| box-shadow: 0 4px 12px rgba(0,0,0,0.1); | |
| font-family: Arial, sans-serif; | |
| min-height: 200px; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| color: white; | |
| width: 100%; | |
| border-left: 4px solid #d97706; | |
| font-weight: bold; | |
| "> | |
| Upload an image to get AI-powered wound analysis | |
| </div> | |
| """), | |
| inputs=None, | |
| outputs=[classification_image_input, classification_output, gemini_output] | |
| ) | |
| # Only run classification on image upload | |
| def classify_and_store(image): | |
| result = classify_wound(image) | |
| return result | |
| classification_image_input.change( | |
| fn=classify_and_store, | |
| inputs=classification_image_input, | |
| outputs=classification_output | |
| ) | |
| # Store image in shared state for next tabs | |
| def store_shared_image(image): | |
| return image | |
| classification_image_input.change( | |
| fn=store_shared_image, | |
| inputs=classification_image_input, | |
| outputs=shared_image | |
| ) | |
| # Run Gemini analysis only when Analyse button is clicked | |
| def run_gemini_on_click(image, classification): | |
| # Get top label | |
| if isinstance(classification, dict) and classification: | |
| top_label = max(classification.items(), key=lambda x: x[1])[0] | |
| else: | |
| top_label = "Unknown" | |
| gemini_analysis = analyze_wound_with_gemini(image, top_label) | |
| formatted_analysis = format_gemini_analysis(gemini_analysis) | |
| return formatted_analysis | |
| analyse_btn.click( | |
| fn=run_gemini_on_click, | |
| inputs=[classification_image_input, classification_output], | |
| outputs=gemini_output | |
| ) | |
    # Tab 2: Depth Estimation
    with gr.Tab("2. Depth Estimation & 3D Visualization"):
        gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
        gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
        with gr.Row():
            load_from_classification_btn = gr.Button("Load Image from Classification Tab", variant="secondary")
        with gr.Row():
            depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
            depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
        with gr.Row():
            depth_submit = gr.Button(value="Compute Depth", variant="primary")
            points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                      label="Number of 3D points (upload an image to update the maximum)")
        with gr.Row():
            focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                       label="Focal Length X (pixels)")
            focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                       label="Focal Length Y (pixels)")
        # Two-column layout: 3D visualization on the left, downloadable files stacked on the right
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 3D Point Cloud Visualization")
                gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see their 3D coordinates.")
                depth_3d_plot = gr.Plot(label="3D Point Cloud")
            with gr.Column(scale=1):
                gr.Markdown("### Download Files")
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be treated as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
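    # Illustrative sketch (hypothetical helper, not the function wired to
    # "Compute Depth"): the focal-length sliders above parameterize a pinhole
    # back-projection. A pixel (u, v) with depth z maps to
    # X = (u - cx) * z / fx, Y = (v - cy) * z / fy, Z = z.
    def _sketch_backproject(depth, fx, fy):
        """Back-project an HxW depth map into an (H*W, 3) array of 3D points."""
        h, w = depth.shape
        u, v = np.meshgrid(np.arange(w), np.arange(h))
        cx, cy = w / 2.0, h / 2.0  # assume the principal point is the image center
        x = (u - cx) * depth / fx
        y = (v - cy) * depth / fy
        return np.stack([x, y, depth], axis=-1).reshape(-1, 3)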
    # Tab 3: Wound Severity Analysis
    with gr.Tab("3. 🩹 Wound Severity Analysis"):
        gr.Markdown("### Step 3: Analyze wound severity using depth maps")
        gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
        with gr.Row():
            # Load the depth map produced in the previous tab
            load_depth_btn = gr.Button("Load Depth Map from Tab 2", variant="secondary")
        with gr.Row():
            severity_input_image = gr.Image(label="Original Image", type='numpy')
            severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
        with gr.Row():
            wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
        with gr.Row():
            severity_output = gr.HTML(
                label="🤖 AI-Powered Medical Assessment",
                value="""
                <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
                    <div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
                        🩹 Wound Severity Analysis
                    </div>
                    <div style='font-size: 18px; color: #cccccc; margin-bottom: 20px;'>
                        ⏳ Waiting for Input...
                    </div>
                    <div style='color: #888888; font-size: 14px;'>
                        Please upload an image and depth map, then click "🤖 Analyze Severity with Auto-Generated Mask" to begin the AI-powered medical assessment.
                    </div>
                </div>
                """
            )
        gr.Markdown("**Note:** The deep learning segmentation model automatically generates a wound mask when you upload an image or load a depth map.")
        with gr.Row():
            auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
            pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                             label="Pixel Spacing (mm/pixel)")
            depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0,
                                                 label="Depth Calibration (mm)",
                                                 info="Adjust based on the expected maximum wound depth")
        # gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")
        # gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10 mm, for deep wounds use 15-30 mm.")
        # gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")
    # Update the point-count slider's maximum when an image is uploaded
    depth_input_image.change(
        fn=update_slider_on_image_upload,
        inputs=[depth_input_image],
        outputs=[points_slider]
    )
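    # For context: `update_slider_on_image_upload` is defined earlier in this
    # file. A plausible sketch (hypothetical `_sketch_update_points_slider`,
    # an assumption rather than the actual implementation) caps the slider at
    # the number of pixels in the uploaded image:
    def _sketch_update_points_slider(image):
        if image is None:
            return gr.update()
        max_pts = int(image.shape[0] * image.shape[1])
        return gr.update(maximum=max_pts, value=min(10000, max_pts))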
    # Wrap on_depth_submit so the normalized depth map is also stored in shared state
    def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
        results = on_depth_submit(image, num_points, focal_x, focal_y)
        # Re-derive the depth map for the severity tab
        depth_map = None
        if image is not None:
            # Gradio supplies RGB; the depth model expects BGR, so flip the channels
            depth = predict_depth(image[:, :, ::-1])
            # Normalize to 8-bit; guard against a constant depth map, which
            # would otherwise divide by zero
            depth_range = max(float(depth.max() - depth.min()), 1e-6)
            norm_depth = (depth - depth.min()) / depth_range * 255.0
            depth_map = norm_depth.astype(np.uint8)
        # Coerce to a list so the shared-state output can be appended
        return list(results) + [depth_map]

    depth_submit.click(on_depth_submit_with_state,
                       inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
                       outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, shared_depth_map])
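    # For context: `predict_depth` is defined earlier in this file. A minimal
    # sketch of such a wrapper (hypothetical `_sketch_predict_depth`) around
    # Depth-Anything-V2, whose DPT model exposes infer_image() taking a BGR
    # image (hence the channel flip above) and returning an HxW float depth map:
    def _sketch_predict_depth(depth_model, bgr_image):
        # depth_model is assumed to be a loaded DepthAnythingV2 instance
        return depth_model.infer_image(bgr_image)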
    # Load the image from the classification tab into the depth tab
    def load_image_from_classification(shared_img):
        if shared_img is None:
            return None, "❌ No image available from the classification tab. Please upload an image in Tab 1 first."
        if hasattr(shared_img, 'convert'):
            # PIL image: convert to a numpy array for depth estimation
            return np.array(shared_img), "✅ Image loaded from the classification tab successfully!"
        # Already a numpy array
        return shared_img, "✅ Image loaded from the classification tab successfully!"

    # The inline gr.HTML() below is instantiated here, so it renders an
    # (initially empty) status area at this point in the layout
    load_from_classification_btn.click(
        fn=load_image_from_classification,
        inputs=shared_image,
        outputs=[depth_input_image, gr.HTML()]
    )
    # Load the depth map into the severity tab and auto-generate a wound mask
    def load_depth_to_severity(depth_map, original_image):
        if depth_map is None:
            return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
        if original_image is None:
            return depth_map, None, None, "✅ Depth map loaded successfully!"
        # Auto-generate a wound mask with the segmentation model
        auto_mask, _ = segmentation_model.segment_wound(original_image)
        if auto_mask is None:
            return depth_map, original_image, None, "❌ Depth map loaded but segmentation failed. Try uploading a different image."
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is not None and np.sum(processed_mask > 0) > 0:
            return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!"
        return depth_map, original_image, None, "❌ Depth map loaded but no wound detected. Try uploading a different image."

    load_depth_btn.click(
        fn=load_depth_to_severity,
        inputs=[shared_depth_map, depth_input_image],
        outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
    )
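    # Illustrative sketch (hypothetical helper, not the app's actual
    # post_process_wound_mask): typical cleanup closes small holes with
    # morphology and keeps only connected components above min_area.
    def _sketch_clean_mask(mask, min_area=500):
        binary = (mask > 0).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
        cleaned = np.zeros_like(binary)
        for i in range(1, n_labels):  # label 0 is the background
            if stats[i, cv2.CC_STAT_AREA] >= min_area:
                cleaned[labels == i] = 255
        return cleaned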
    # Loading state shown while the analysis chain runs
    def show_loading_state():
        return """
        <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
            <div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
                🩹 Wound Severity Analysis
            </div>
            <div style='font-size: 18px; color: #4CAF50; margin-bottom: 20px;'>
                AI Analysis in Progress...
            </div>
            <div style='color: #cccccc; font-size: 14px; margin-bottom: 15px;'>
                • Generating wound mask with the deep learning model<br>
                • Computing depth measurements and statistics<br>
                • Analyzing wound characteristics with Gemini AI<br>
                • Preparing comprehensive medical assessment
            </div>
            <div style='display: inline-block; width: 30px; height: 30px; border: 3px solid #f3f3f3; border-top: 3px solid #4CAF50; border-radius: 50%; animation: spin 1s linear infinite;'></div>
            <style>
                @keyframes spin {
                    0% { transform: rotate(0deg); }
                    100% { transform: rotate(360deg); }
                }
            </style>
        </div>
        """
    # Automatic severity analysis using the auto-generated mask
    def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration):
        if depth_map is None:
            return """
            <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
                <div style='font-size: 24px; font-weight: bold; color: #f44336; margin-bottom: 15px;'>
                    ❌ Error
                </div>
                <div style='font-size: 16px; color: #cccccc;'>
                    Please load the depth map from Tab 2 first.
                </div>
            </div>
            """
        # Generate an automatic wound mask using the actual segmentation model
        auto_mask = create_automatic_wound_mask(image, method='deep_learning')
        if auto_mask is None:
            return """
            <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
                <div style='font-size: 24px; font-weight: bold; color: #f44336; margin-bottom: 15px;'>
                    ❌ Error
                </div>
                <div style='font-size: 16px; color: #cccccc;'>
                    Failed to generate an automatic wound mask. Please check that the segmentation model is loaded.
                </div>
            </div>
            """
        # Post-process the mask with a fixed minimum area
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return """
            <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
                <div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
                    ⚠️ No Wound Detected
                </div>
                <div style='font-size: 16px; color: #cccccc;'>
                    No wound region was detected by the segmentation model. Try a different image or use a manual mask.
                </div>
            </div>
            """
        # Analyze severity using the automatic mask
        return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration)

    # Show the loading state immediately, then run the full analysis
    auto_severity_button.click(
        fn=show_loading_state,
        inputs=[],
        outputs=[severity_output]
    ).then(
        fn=run_auto_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider],
        outputs=[severity_output]
    )
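    # Illustrative sketch (hypothetical helper, not the app's actual
    # analyze_wound_severity): the two sliders convert pixel measurements into
    # physical units. Assuming a single-channel 8-bit depth map, pixel_spacing
    # in mm/pixel, and depth_calibration as the assumed maximum wound depth in mm:
    def _sketch_wound_metrics(depth_map, mask, pixel_spacing, depth_calibration):
        wound = mask > 0
        area_mm2 = float(np.sum(wound)) * pixel_spacing ** 2  # pixel count -> mm^2
        depths_mm = depth_map[wound].astype(np.float32) / 255.0 * depth_calibration
        return {
            "area_mm2": area_mm2,
            "mean_depth_mm": float(depths_mm.mean()) if depths_mm.size else 0.0,
            "max_depth_mm": float(depths_mm.max()) if depths_mm.size else 0.0,
        }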
    # Auto-generate a mask when an image is uploaded to the severity tab
    def auto_generate_mask_on_image_upload(image):
        if image is None:
            return None, "❌ No image uploaded."
        # Generate an automatic wound mask with the segmentation model
        auto_mask, _ = segmentation_model.segment_wound(image)
        if auto_mask is None:
            return None, "❌ Image uploaded but segmentation failed. Try uploading a different image."
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is not None and np.sum(processed_mask > 0) > 0:
            return processed_mask, "✅ Wound mask auto-generated using the deep learning model!"
        return None, "❌ Image uploaded but no wound detected. Try uploading a different image."

    # Load the shared image from the classification tab.
    # Note: defined for parity with Tab 2's loader but not currently wired to any event.
    def load_shared_image(shared_img):
        if shared_img is None:
            return gr.Image(), "❌ No image available from the classification tab"
        if hasattr(shared_img, 'convert'):
            # PIL image: convert to a numpy array
            return np.array(shared_img), "✅ Image loaded from the classification tab"
        # Already a numpy array
        return shared_img, "✅ Image loaded from the classification tab"

    severity_input_image.change(
        fn=auto_generate_mask_on_image_upload,
        inputs=[severity_input_image],
        outputs=[wound_mask_input, gr.HTML()]
    )
if __name__ == '__main__':
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )