import logging
import gc
import time
from typing import Dict, Any, Optional, Callable
from dataclasses import dataclass, field
from threading import Lock

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@dataclass
class ModelInfo:
    """Information about a registered model."""
    name: str
    loader: Callable[[], Any]
    is_critical: bool = False  # Critical models are not unloaded under memory pressure
    estimated_memory_gb: float = 0.0
    is_loaded: bool = False
    last_used: float = 0.0
    model_instance: Any = None

class ModelManager:
    """
    Singleton model manager for unified model lifecycle management.
    Handles lazy loading, caching, and intelligent memory management.
    """

    _instance = None
    _lock = Lock()

    def __new__(cls):
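        # Double-checked locking: the first check avoids taking the lock on the
        # common path, and the second check inside the lock guarantees only one
        # thread ever creates the instance.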
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._models: Dict[str, ModelInfo] = {}
        self._memory_threshold = 0.80  # Trigger cleanup at 80% GPU memory usage
        self._device = self._detect_device()
        logger.info(f"🔧 ModelManager initialized on {self._device}")
        self._initialized = True

    def _detect_device(self) -> str:
        """Detect best available device."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def register_model(
        self,
        name: str,
        loader: Callable[[], Any],
        is_critical: bool = False,
        estimated_memory_gb: float = 0.0
    ):
        """
        Register a model for managed loading.

        Args:
            name: Unique model identifier
            loader: Callable that returns the loaded model
            is_critical: If True, model won't be unloaded under memory pressure
            estimated_memory_gb: Estimated GPU memory usage in GB
        """
        if name in self._models:
            logger.warning(f"⚠️ Model '{name}' already registered, updating")

        self._models[name] = ModelInfo(
            name=name,
            loader=loader,
            is_critical=is_critical,
            estimated_memory_gb=estimated_memory_gb,
            is_loaded=False,
            last_used=0.0,
            model_instance=None
        )
        logger.info(f"📝 Registered model: {name} (critical={is_critical}, ~{estimated_memory_gb:.1f}GB)")

    def load_model(self, name: str) -> Any:
        """
        Load a model by name. Returns the cached instance if already loaded.

        Args:
            name: Model identifier

        Returns:
            Loaded model instance

        Raises:
            KeyError: If the model is not registered
            RuntimeError: If loading fails
        """
        if name not in self._models:
            raise KeyError(f"Model '{name}' not registered")

        model_info = self._models[name]

        # Return cached instance
        if model_info.is_loaded and model_info.model_instance is not None:
            model_info.last_used = time.time()
            logger.debug(f"📦 Using cached model: {name}")
            return model_info.model_instance

        # Check memory pressure before loading
        self.check_memory_pressure()

        # Load the model
        try:
            logger.info(f"📥 Loading model: {name}")
            start_time = time.time()

            model_instance = model_info.loader()

            model_info.model_instance = model_instance
            model_info.is_loaded = True
            model_info.last_used = time.time()

            load_time = time.time() - start_time
            logger.info(f"✅ Model '{name}' loaded in {load_time:.1f}s")
            return model_instance
        except Exception as e:
            logger.error(f"❌ Failed to load model '{name}': {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def unload_model(self, name: str):
        """
        Unload a specific model to free memory.

        Args:
            name: Model identifier
        """
        if name not in self._models:
            return

        model_info = self._models[name]
        if not model_info.is_loaded:
            return

        try:
            logger.info(f"🗑️ Unloading model: {name}")

            # Delete model instance
            if model_info.model_instance is not None:
                del model_info.model_instance
                model_info.model_instance = None

            model_info.is_loaded = False

            # Cleanup
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ Model '{name}' unloaded")
        except Exception as e:
            logger.error(f"❌ Error unloading model '{name}': {e}")

    def check_memory_pressure(self) -> bool:
        """
        Check GPU memory usage and unload least-used non-critical models if needed.

        Returns:
            True if cleanup was performed
        """
        if not torch.cuda.is_available():
            return False
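        # Note: memory_allocated() counts only tensors held by this process's
        # PyTorch caching allocator; memory reserved but not currently used by
        # the allocator, or used by other processes, is not included, so tools
        # like nvidia-smi may report higher usage than this ratio.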
        allocated = torch.cuda.memory_allocated() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        usage_ratio = allocated / total

        if usage_ratio < self._memory_threshold:
            return False

        logger.warning(f"⚠️ Memory pressure detected: {usage_ratio:.1%} used")

        # Find non-critical loaded models, sorted by last-used time
        unloadable = [
            (name, info) for name, info in self._models.items()
            if info.is_loaded and not info.is_critical
        ]
        unloadable.sort(key=lambda x: x[1].last_used)

        # Unload oldest non-critical models
        cleaned = False
        for name, info in unloadable:
            self.unload_model(name)
            cleaned = True

            # Re-check memory
            new_ratio = torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory
            if new_ratio < self._memory_threshold * 0.7:  # Target 70% of threshold
                break

        return cleaned

    def force_cleanup(self):
        """Force-unload all non-critical models and clear caches."""
        logger.info("🧹 Force cleanup initiated")

        # Unload all non-critical models
        for name, info in self._models.items():
            if info.is_loaded and not info.is_critical:
                self.unload_model(name)

        # Aggressive garbage collection
        for _ in range(5):
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

        logger.info("✅ Force cleanup completed")

    def get_memory_status(self) -> Dict[str, Any]:
        """
        Get detailed memory status.

        Returns:
            Dictionary with memory statistics
        """
        status = {
            "device": self._device,
            "models": {},
            "total_estimated_gb": 0.0
        }

        # Model status
        for name, info in self._models.items():
            status["models"][name] = {
                "loaded": info.is_loaded,
                "critical": info.is_critical,
                "estimated_gb": info.estimated_memory_gb,
                "last_used": info.last_used
            }
            if info.is_loaded:
                status["total_estimated_gb"] += info.estimated_memory_gb

        # GPU memory
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3
            status["gpu"] = {
                "allocated_gb": round(allocated, 2),
                "total_gb": round(total, 2),
                "cached_gb": round(cached, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1)
            }

        return status

    def get_loaded_models(self) -> list:
        """Get list of currently loaded model names."""
        return [name for name, info in self._models.items() if info.is_loaded]

    def is_model_loaded(self, name: str) -> bool:
        """Check if a specific model is loaded."""
        if name not in self._models:
            return False
        return self._models[name].is_loaded

# Global singleton instance
_model_manager = None


def get_model_manager() -> ModelManager:
    """Get the global ModelManager singleton instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
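
# --- Usage sketch (illustrative only; the model name, loader, and memory
# estimate below are hypothetical examples, not part of this module) ---
#
#     from transformers import AutoModel
#
#     manager = get_model_manager()
#     manager.register_model(
#         name="text-encoder",  # hypothetical identifier
#         loader=lambda: AutoModel.from_pretrained("bert-base-uncased"),
#         is_critical=False,
#         estimated_memory_gb=0.5,  # rough estimate, not measured
#     )
#
#     encoder = manager.load_model("text-encoder")  # loads lazily on first use
#     encoder = manager.load_model("text-encoder")  # second call returns the cached instance
#     print(manager.get_memory_status())            # per-model and GPU memory summary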