"""Wound analysis and depth estimation Gradio app.

Tab 1 classifies a wound photo with a Keras model, tab 2 estimates a relative
depth map with Depth Anything V2 (with point-cloud/mesh export and a Plotly 3D
view), and tab 3 estimates wound severity from that depth map plus a wound
mask that is uploaded, synthesised, or segmented automatically.
"""

# Imported first: on Hugging Face ZeroGPU, `spaces` must be loaded before
# torch/tensorflow initialise CUDA.
import spaces

import base64
import glob
import os
import tempfile
from io import BytesIO

import cv2
import gdown
import gradio as gr
import matplotlib
import numpy as np
import open3d as o3d
import plotly.express as px
import plotly.graph_objects as go
import tensorflow as tf
import torch
from gradio_imageslider import ImageSlider
from PIL import Image
from skimage import filters, measure, morphology
from skimage.segmentation import clear_border
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image

from depth_anything_v2.dpt import DepthAnythingV2

# --- Depth model checkpoint: download once into ./checkpoints ---
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: check GPU availability ---
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # TensorFlow reserves nearly all GPU memory by default, which would starve
    # the PyTorch depth model running on the same device.
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            # Only settable before the GPU is initialised; keep TF's default
            # allocation if that has already happened.
            pass
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")

# --- Load wound classification model and class labels ---
wound_model = load_model("/home/user/app/keras_model.h5")
with open("/home/user/app/labels.txt", "r", encoding="utf-8") as f:
    # Keep the text after the first whitespace-separated token of every
    # non-empty line (presumably "<index> <label>"); blank lines are skipped.
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f if line.strip()]

# --- PyTorch: set device and load depth model ---
map_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {map_device}")

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}

encoder = 'vitl'  # must match the checkpoint stored in `model_file`
depth_model = DepthAnythingV2(**model_configs[encoder])
# Load from the same path the download step above writes to, so the app does
# not depend on being run from /home/user/app.
state_dict = torch.load(model_file, map_location=map_device)
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()

# --- Custom CSS for unified dark theme ---
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
"""


# --- Wound classification ---

# Fixed explanations shown next to a prediction (stand-in for a Gemini call).
_WOUND_EXPLANATIONS = {
    "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
    "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
    "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
    "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
    "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues.",
}


def preprocess_input(img):
    """Turn a PIL image into a (1, 224, 224, 3) float batch scaled to [0, 1]."""
    # Uploads may be RGBA, palette or greyscale; the classifier takes 3 channels.
    img = img.convert("RGB").resize((224, 224))
    arr = keras_image.img_to_array(img) / 255.0
    return np.expand_dims(arr, axis=0)


def get_reasoning_from_gemini(img, prediction):
    """Return a short explanation for the predicted wound class.

    Placeholder for a Gemini API call: ``img`` is currently unused and the
    text comes from a fixed lookup table, with a generic fallback for classes
    that are not in the table.
    """
    return _WOUND_EXPLANATIONS.get(
        prediction,
        f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment.",
    )


@spaces.GPU
def classify_wound_image(img):
    """Classify a PIL wound image.

    Returns:
        Two HTML strings: the prediction card and the reasoning card (the
        second is empty when no image is given).
    """
    if img is None:
        return "<div style='color:#ff6b6b; font-weight:bold;'>No image provided</div>", ""

    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    reasoning_text = get_reasoning_from_gemini(img, pred_class)

    predicted_card = f"""
<div style="background-color:#1e1e1e; border-radius:12px; padding:16px; margin-bottom:12px;">
    <div style="font-size:1.1rem; color:#bbbbbb;">Predicted Wound Type</div>
    <div style="font-size:1.8rem; font-weight:bold; color:#4CAF50;">{pred_class}</div>
</div>
"""
    reasoning_card = f"""
<div style="background-color:#1e1e1e; border-radius:12px; padding:16px;">
    <div style="font-size:1.1rem; color:#bbbbbb; margin-bottom:8px;">Reasoning</div>
    <div style="font-size:1rem; line-height:1.5;">{reasoning_text}</div>
</div>
"""
    return predicted_card, reasoning_card


# --- Wound severity estimation ---

def _to_single_channel(arr):
    """Collapse an (H, W, C) array to (H, W) by averaging its first 3 channels.

    gr.Image(type='numpy') returns RGB arrays even for greyscale depth maps
    and masks; 2-D input is returned unchanged.
    """
    arr = np.asarray(arr)
    if arr.ndim == 3:
        arr = arr[..., :3].mean(axis=2)
    return arr


def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
    """Compute wound area per depth band.

    Args:
        depth_map: 2-D depth array, same height/width as ``mask``.
        mask: 2-D mask; pixels > 127 count as wound.
        pixel_spacing_mm: physical size of one pixel edge in millimetres.

    Returns:
        Dict with total/shallow/moderate/deep areas (cm²), the deep-area ratio
        and the maximum depth value inside the wound.
    """
    # NOTE(review): the depth map produced by tab 2 is a 0-255 normalised
    # relative depth (disparity-like), not millimetres, so the 3/6 thresholds
    # below and the "mm" labels in the report assume a calibration that this
    # app does not perform — confirm before relying on the numbers.
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    wound_mask = mask > 127
    wound_depths = depth_map[wound_mask]
    total_area = np.sum(wound_mask) * pixel_area_cm2

    shallow = wound_depths < 3
    moderate = (wound_depths >= 3) & (wound_depths < 6)
    deep = wound_depths >= 6

    shallow_area = np.sum(shallow) * pixel_area_cm2
    moderate_area = np.sum(moderate) * pixel_area_cm2
    deep_area = np.sum(deep) * pixel_area_cm2

    deep_ratio = deep_area / total_area if total_area > 0 else 0

    return {
        'total_area_cm2': float(total_area),
        'shallow_area_cm2': float(shallow_area),
        'moderate_area_cm2': float(moderate_area),
        'deep_area_cm2': float(deep_area),
        'deep_ratio': float(deep_ratio),
        'max_depth': float(np.max(wound_depths)) if wound_depths.size > 0 else 0.0,
    }


def classify_wound_severity_by_area(depth_stats):
    """Map area statistics to "Mild", "Moderate", "Severe" or "Unknown"."""
    total = depth_stats['total_area_cm2']
    deep = depth_stats['deep_area_cm2']
    moderate = depth_stats['moderate_area_cm2']

    if total == 0:
        return "Unknown"

    if deep > 2 or (deep / total) > 0.3:
        return "Severe"
    if moderate > 1.5 or (moderate / total) > 0.4:
        return "Moderate"
    return "Mild"


def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Build an HTML severity report from a depth map and a wound mask.

    ``depth_map`` and ``wound_mask`` may be 2-D or (H, W, C); both are reduced
    to one channel. The mask is resized (nearest-neighbour) to the depth map's
    size when they differ. ``image`` is only checked for presence.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload image, depth map, and wound mask."

    # Without this, an RGB depth map makes depth_map[mask] return (N, 3)
    # values and every depth-band area is counted three times.
    depth_map = _to_single_channel(depth_map)
    wound_mask = _to_single_channel(wound_mask)

    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Nearest-neighbour keeps the mask binary (no interpolated grey edges).
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]), resample=Image.NEAREST)
        wound_mask = np.array(mask_pil)

    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    severity_color = {
        "Mild": "#4CAF50",      # Green
        "Moderate": "#FF9800",  # Orange
        "Severe": "#F44336",    # Red
    }.get(severity, "#9E9E9E")  # Gray for unknown

    report = f"""
<div style="background-color:#1e1e1e; border-radius:12px; padding:20px; color:#ffffff;">
    <h3 style="margin-top:0;">🩹 Wound Severity Analysis</h3>
    <h4>📏 Area Measurements</h4>
    <div>🟢 <b>Total Area:</b> {stats['total_area_cm2']:.2f} cm²</div>
    <div>🟩 <b>Shallow (0-3mm):</b> {stats['shallow_area_cm2']:.2f} cm²</div>
    <div>🟨 <b>Moderate (3-6mm):</b> {stats['moderate_area_cm2']:.2f} cm²</div>
    <div>🟥 <b>Deep (>6mm):</b> {stats['deep_area_cm2']:.2f} cm²</div>
    <h4>📊 Depth Analysis</h4>
    <div>🔥 <b>Deep Coverage:</b> {stats['deep_ratio']*100:.1f}%</div>
    <div>📏 <b>Max Depth:</b> {stats['max_depth']:.1f} mm</div>
    <div><b>Pixel Spacing:</b> {pixel_spacing_mm} mm</div>
    <div style="margin-top:16px; padding:12px; border-radius:8px; background-color:{severity_color};">
        <div style="font-size:1.3rem; font-weight:bold;">🎯 Predicted Severity: {severity}</div>
        <div>{get_severity_description(severity)}</div>
    </div>
</div>
"""
    return report


def get_severity_description(severity):
    """Return a one-sentence description for a severity label."""
    descriptions = {
        "Mild": "Superficial wound with minimal tissue damage. Usually heals well with basic care.",
        "Moderate": "Moderate tissue involvement requiring careful monitoring and proper treatment.",
        "Severe": "Deep tissue damage requiring immediate medical attention and specialized care.",
        "Unknown": "Unable to determine severity due to insufficient data.",
    }
    return descriptions.get(severity, "Severity assessment unavailable.")


def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Return a uint8 mask (0/255) with a filled circle; center is (x, y)."""
    if center is None:
        center = (image_shape[1] // 2, image_shape[0] // 2)

    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    y, x = np.ogrid[:image_shape[0], :image_shape[1]]
    dist_from_center = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2)
    mask[dist_from_center <= radius] = 255
    return mask


def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Return a synthetic uint8 (0/255) wound mask centred in the image.

    ``method='elliptical'`` gives an ellipse with a randomly roughened rim;
    ``method='irregular'`` gives a circle with three lobes. Any other value
    gives an empty mask. The result is smoothed with a 5x5 closing.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    y, x = np.ogrid[:h, :w]
    center = (w // 2, h // 2)

    if method == 'elliptical':
        # max(1, ...) avoids division by zero on very small images.
        radius_x = max(1, min(w, h) // 3)
        radius_y = max(1, min(w, h) // 4)
        # Normalised elliptical distance: <= 1 inside the ellipse.
        r = (x - center[0]) ** 2 / radius_x ** 2 + (y - center[1]) ** 2 / radius_y ** 2
        ellipse = r <= 1
        # Roughen only the rim: random pixels in a thin band just outside the
        # ellipse. Unrestricted noise would mark ~20% of the background as wound.
        rim = (r > 1) & (r <= 1.2)
        noise = (np.random.random((h, w)) > 0.8) & rim
        mask = (ellipse | noise).astype(np.uint8) * 255

    elif method == 'irregular':
        radius = min(w, h) // 4
        base_circle = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2) <= radius

        extensions = np.zeros_like(base_circle)
        for i in range(3):
            angle = i * 2 * np.pi / 3
            ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
            ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
            ext_radius = radius // 3
            ext_circle = np.sqrt((x - ext_x) ** 2 + (y - ext_y) ** 2) <= ext_radius
            extensions = extensions | ext_circle

        mask = (base_circle | extensions).astype(np.uint8) * 255

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask


# --- Depth estimation ---

@spaces.GPU
def predict_depth(image):
    """Run Depth Anything V2 on ``image`` (BGR, as passed by callers)."""
    return depth_model.infer_image(image)


def calculate_max_points(image):
    """Upper bound for the point slider: 3x pixel count, clamped to [1000, 300000]."""
    if image is None:
        return 10000  # Default value
    h, w = image.shape[:2]
    max_points = h * w * 3
    return max(1000, min(max_points, 300000))


def update_slider_on_image_upload(image):
    """Return a point-count slider sized for the uploaded image."""
    max_points = calculate_max_points(image)
    # 10% of the max, capped at 10000 and never below the slider minimum
    # (tiny images would otherwise produce a value outside the range).
    default_value = max(1000, min(10000, max_points // 10))
    return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {max_points:,})")


def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Back-project a depth map into a coloured Open3D point cloud.

    Pixels are subsampled with a stride chosen from ``max_points``; the extra
    0.5 factor deliberately keeps up to ~4x ``max_points`` points for detail.
    ``image`` must have the same height/width as ``depth_map``.
    """
    h, w = depth_map.shape
    step = max(1, int(np.sqrt(h * w / max_points) * 0.5))

    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Pinhole model with the principal point at the image centre.
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y

    depth_values = depth_map[::step, ::step]
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)

    image_colors = image[::step, ::step, :3]
    colors = image_colors.reshape(-1, 3) / 255.0

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd


def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Build a triangle mesh from ``pcd`` with Poisson reconstruction (depth 12).

    Low-density vertices are not filtered out.
    """
    # NOTE(review): radius=0.005 is tiny relative to the point coordinates
    # produced from a 0-255 depth map, so the KNN cap likely dominates normal
    # estimation — confirm the intended scale.
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)

    mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
    return mesh


def create_enhanced_3d_visualization(image, depth_map, max_points=10000,
                                     focal_length_x=470.4, focal_length_y=470.4):
    """Return a Plotly 3D scatter of the back-projected, image-coloured depth map."""
    h, w = depth_map.shape
    step = max(1, int(np.sqrt(h * w / max_points)))

    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y

    depth_values = depth_map[::step, ::step]
    x_flat = (x_cam * depth_values).flatten()
    y_flat = (y_cam * depth_values).flatten()
    z_flat = depth_values.flatten()

    # Plotly expects one colour string per marker, not an (N, 3) numeric array.
    image_colors = image[::step, ::step, :3].reshape(-1, 3)
    colors_flat = [f"rgb({r},{g},{b})" for r, g, b in image_colors]

    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(size=1.5, color=colors_flat, opacity=0.9),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>'
                      '<b>Depth:</b> %{z:.2f}<br>'
                      '<extra></extra>',
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1),
            ),
            aspectmode='data',
        ),
        width=700,
        height=600,
    )
    return fig


def _normalize_depth(depth):
    """Linearly rescale ``depth`` to uint8 0-255 (all zeros if it is constant)."""
    d_min, d_max = float(depth.min()), float(depth.max())
    if d_max <= d_min:
        return np.zeros(depth.shape, dtype=np.uint8)
    return ((depth - d_min) / (d_max - d_min) * 255.0).astype(np.uint8)


def _new_temp_path(suffix):
    """Create an empty, closed temporary file that persists; return its path."""
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def _run_depth_pipeline(image, num_points, focal_x, focal_y):
    """Run depth inference once and build every tab-2 output.

    Returns:
        (outputs, norm_depth), where ``outputs`` is the 5-element list that
        ``on_depth_submit`` returns and ``norm_depth`` is the uint8 depth map.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    original_image = image.copy()

    depth = predict_depth(image[:, :, ::-1])  # RGB -> BGR for infer_image

    raw_depth_path = _new_temp_path('.png')
    Image.fromarray(depth.astype('uint16')).save(raw_depth_path)

    norm_depth = _normalize_depth(depth)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth_path = _new_temp_path('.png')
    Image.fromarray(norm_depth).save(gray_depth_path)

    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
    mesh_path = _new_temp_path('.ply')
    o3d.io.write_triangle_mesh(mesh_path, mesh)

    depth_3d = create_enhanced_3d_visualization(
        original_image, norm_depth, max_points=num_points,
        focal_length_x=focal_x, focal_length_y=focal_y,
    )

    outputs = [(original_image, colored_depth), gray_depth_path, raw_depth_path, mesh_path, depth_3d]
    return outputs, norm_depth


def on_depth_submit(image, num_points, focal_x, focal_y):
    """Return [slider pair, grey depth path, raw depth path, mesh path, 3D figure]."""
    outputs, _ = _run_depth_pipeline(image, num_points, focal_x, focal_y)
    return outputs


# --- Automatic wound mask generation ---

def _keep_large_contours(mask, min_area, inclusive=False):
    """Return a new mask with only the filled external contours above ``min_area``.

    With ``inclusive=True`` contours whose area equals ``min_area`` are kept too.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > min_area or (inclusive and area == min_area):
            cv2.fillPoly(mask_clean, [contour], 255)
    return mask_clean


def create_automatic_wound_mask(image, method='adaptive'):
    """Segment a wound mask from an RGB (or greyscale) uint8 image.

    Args:
        image: input image as a numpy array.
        method: 'adaptive', 'otsu', 'color' or 'combined'; anything else
            falls back to 'adaptive'.

    Returns:
        uint8 mask (0/255), or None when ``image`` is None.
    """
    if image is None:
        return None

    if image.ndim == 3:
        # Drop an alpha channel so the RGB conversions below accept the input.
        image = np.ascontiguousarray(image[..., :3])
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image.copy()

    if method == 'otsu':
        return otsu_threshold_segmentation(gray)
    if method == 'color':
        return color_based_segmentation(image)
    if method == 'combined':
        return combined_segmentation(image, gray)
    return adaptive_threshold_segmentation(gray)


def adaptive_threshold_segmentation(gray):
    """Segment dark regions with Gaussian adaptive thresholding; keep blobs > 1000 px."""
    blurred = cv2.GaussianBlur(gray, (15, 15), 0)
    thresh = cv2.adaptiveThreshold(
        blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5
    )
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    return _keep_large_contours(mask, 1000)


def otsu_threshold_segmentation(gray):
    """Segment dark regions with inverted Otsu thresholding; keep blobs > 800 px."""
    blurred = cv2.GaussianBlur(gray, (15, 15), 0)
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    return _keep_large_contours(mask, 800)


def color_based_segmentation(image):
    """Segment red/yellow/brown HSV hue ranges of an RGB image; keep blobs > 600 px."""
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    # Red wraps around hue 0/180, hence two ranges.
    mask1 = cv2.inRange(hsv, np.array([0, 30, 30]), np.array([15, 255, 255]))
    mask2 = cv2.inRange(hsv, np.array([160, 30, 30]), np.array([180, 255, 255]))
    red_mask = cv2.bitwise_or(mask1, mask2)

    yellow_mask = cv2.inRange(hsv, np.array([15, 30, 30]), np.array([35, 255, 255]))
    brown_mask = cv2.inRange(hsv, np.array([10, 50, 20]), np.array([20, 255, 200]))

    # bitwise_or, not '+': the hue ranges overlap and uint8 addition wraps.
    color_mask = cv2.bitwise_or(cv2.bitwise_or(red_mask, yellow_mask), brown_mask)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)
    return _keep_large_contours(color_mask, 600)


def combined_segmentation(image, gray):
    """Union of adaptive, Otsu and colour masks, closed, keeping blobs > 500 px.

    Returns an empty mask when nothing is found so callers can report "no
    wound detected" instead of analysing a fabricated region.
    """
    adaptive_mask = adaptive_threshold_segmentation(gray)
    otsu_mask = otsu_threshold_segmentation(gray)
    color_mask = color_based_segmentation(image)

    combined_mask = cv2.bitwise_or(adaptive_mask, otsu_mask)
    combined_mask = cv2.bitwise_or(combined_mask, color_mask)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
    return _keep_large_contours(combined_mask, 500)


def post_process_wound_mask(mask, min_area=100):
    """Close/open the mask, drop blobs smaller than ``min_area`` px, then close again."""
    if mask is None:
        return None

    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    mask_clean = _keep_large_contours(mask, min_area, inclusive=True)
    # Fill small holes left inside the kept regions.
    return cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)


def _generate_wound_mask(image, segmentation_method, min_area):
    """Segment and post-process a wound mask; None if segmentation failed."""
    auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
    return post_process_wound_mask(auto_mask, min_area=min_area)


def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5,
                                segmentation_method='combined', min_area=500):
    """Analyse severity using an automatically generated wound mask."""
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."

    processed_mask = _generate_wound_mask(image, segmentation_method, min_area)
    if processed_mask is None:
        return "❌ Failed to generate automatic wound mask."
    if not np.any(processed_mask > 0):
        return "❌ No wound region detected. Try adjusting segmentation parameters or upload a manual mask."

    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)


# --- UI event handlers ---

def generate_sample_mask(image):
    """Return (circular sample mask, status) for the loaded image."""
    if image is None:
        return None, "❌ Please load an image first."
    return create_sample_wound_mask(image.shape), "✅ Sample circular wound mask generated!"


def generate_realistic_mask(image):
    """Return (elliptical synthetic mask, status) for the loaded image."""
    if image is None:
        return None, "❌ Please load an image first."
    realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical')
    return realistic_mask, "✅ Realistic elliptical wound mask generated!"


def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
    """Run the depth pipeline once; also return the uint8 depth map for tab 3."""
    outputs, norm_depth = _run_depth_pipeline(image, num_points, focal_x, focal_y)
    return outputs + [norm_depth]


def load_depth_to_severity(depth_map, original_image):
    """Copy the stored depth map and input image into the severity tab."""
    if depth_map is None:
        return None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
    return depth_map, original_image, "✅ Depth map loaded successfully!"


def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area):
    """Severity report using an automatic mask with a user-chosen minimum area."""
    if depth_map is None:
        return "❌ Please load depth map from Tab 2 first."
    if image is None:
        return "❌ Please load an image first."

    processed_mask = _generate_wound_mask(image, seg_method, min_area)
    if processed_mask is None:
        return "❌ Failed to generate automatic wound mask."
    if not np.any(processed_mask > 0):
        return "❌ No wound region detected. Try adjusting segmentation parameters or use manual mask."

    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)


def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
    """Severity report using an uploaded wound mask."""
    if depth_map is None:
        return "❌ Please load depth map from Tab 2 first."
    if wound_mask is None:
        return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."
    return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)


def preview_auto_mask(image, seg_method, min_area):
    """Return (automatic mask, status) so the user can inspect the segmentation."""
    if image is None:
        return None, "❌ Please load an image first."

    processed_mask = _generate_wound_mask(image, seg_method, min_area)
    if processed_mask is None:
        return None, "❌ Failed to generate automatic wound mask."
    if not np.any(processed_mask > 0):
        return None, "❌ No wound region detected. Try adjusting parameters."

    return processed_mask, f"✅ Auto mask generated using {seg_method} method!"


def load_shared_image(shared_img):
    """Return the classification-tab image as a numpy array plus a status message."""
    if shared_img is None:
        return gr.update(), "❌ No image available from classification tab"
    if hasattr(shared_img, 'convert'):
        # PIL image from tab 1 -> numpy for the depth input component.
        return np.array(shared_img), "✅ Image loaded from classification tab"
    return shared_img, "✅ Image loaded from classification tab"


def pass_image_to_depth(img):
    """Status message telling the user how to move the image to tab 2."""
    if img is None:
        return "❌ No image uploaded in classification tab"
    return "✅ Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"


# --- Main Gradio interface ---
with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
    gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
    gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Image shared between the classification and depth tabs.
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")

            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)
                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()

            with gr.Row():
                pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg")
            pass_status = gr.HTML("")

        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")

            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary")
            load_shared_status = gr.HTML("")
            points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                      label="Number of 3D points (upload image to update max)")

            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")

            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # uint8 depth map kept for the severity tab.
            depth_map_state = gr.State()

        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. 🩹 Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")

            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')

            with gr.Row():
                wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")

            gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.")

            with gr.Row():
                auto_severity_button = gr.Button("🤖 Auto-Analyze Severity", variant="primary", size="lg")
                manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg")
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")

            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")

            with gr.Row():
                segmentation_method = gr.Dropdown(
                    choices=["combined", "adaptive", "otsu", "color"],
                    value="combined",
                    label="Segmentation Method",
                    info="Choose automatic segmentation method",
                )
                min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100,
                                            label="Minimum Area (pixels)",
                                            info="Minimum wound area to detect")

            with gr.Row():
                load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
                sample_mask_btn = gr.Button("🎯 Generate Sample Mask", variant="secondary")
                realistic_mask_btn = gr.Button("🏥 Generate Realistic Mask", variant="secondary")
                preview_mask_btn = gr.Button("👁️ Preview Auto Mask", variant="secondary")

            gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.")
            severity_status = gr.HTML("")

    # --- Event wiring: Tab 1 ---
    wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                             outputs=[wound_prediction_box, wound_reasoning_box])
    wound_image_input.change(fn=lambda img: img, inputs=[wound_image_input], outputs=[shared_image])
    pass_to_depth_btn.click(fn=pass_image_to_depth, inputs=[shared_image], outputs=[pass_status])

    # --- Event wiring: Tab 2 ---
    depth_input_image.change(fn=update_slider_on_image_upload, inputs=[depth_input_image],
                             outputs=[points_slider])
    depth_submit.click(
        on_depth_submit_with_state,
        inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
        outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state],
    )
    load_shared_btn.click(fn=load_shared_image, inputs=[shared_image],
                          outputs=[depth_input_image, load_shared_status])

    # --- Event wiring: Tab 3 ---
    sample_mask_btn.click(fn=generate_sample_mask, inputs=[severity_input_image],
                          outputs=[wound_mask_input, severity_status])
    realistic_mask_btn.click(fn=generate_realistic_mask, inputs=[severity_input_image],
                             outputs=[wound_mask_input, severity_status])
    load_depth_btn.click(fn=load_depth_to_severity, inputs=[depth_map_state, depth_input_image],
                         outputs=[severity_depth_map, severity_input_image, severity_status])
    auto_severity_button.click(
        fn=run_auto_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, segmentation_method, min_area_slider],
        outputs=[severity_output],
    )
    manual_severity_button.click(
        fn=run_manual_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider],
        outputs=[severity_output],
    )
    preview_mask_btn.click(fn=preview_auto_mask,
                           inputs=[severity_input_image, segmentation_method, min_area_slider],
                           outputs=[wound_mask_input, severity_status])


if __name__ == '__main__':
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )