import logging
import gc
import time
from typing import Dict, Any, Optional, Callable
from dataclasses import dataclass, field
from threading import Lock

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class ModelInfo:
    """Information about a registered model."""
    name: str
    loader: Callable[[], Any]
    is_critical: bool = False  # Critical models are not unloaded under memory pressure
    estimated_memory_gb: float = 0.0
    is_loaded: bool = False
    last_used: float = 0.0
    model_instance: Any = None


class ModelManager:
    """
    Singleton model manager for unified model lifecycle management.
    Handles lazy loading, caching, and intelligent memory management.
    """

    _instance = None
    _lock = Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        self._models: Dict[str, ModelInfo] = {}
        self._memory_threshold = 0.80  # Trigger cleanup at 80% GPU memory usage
        self._device = self._detect_device()

        logger.info(f"🔧 ModelManager initialized on {self._device}")
        self._initialized = True

    def _detect_device(self) -> str:
        """Detect best available device."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def register_model(
        self,
        name: str,
        loader: Callable[[], Any],
        is_critical: bool = False,
        estimated_memory_gb: float = 0.0
    ):
        """
        Register a model for managed loading.

        Args:
            name: Unique model identifier
            loader: Callable that returns the loaded model
            is_critical: If True, model won't be unloaded under memory pressure
            estimated_memory_gb: Estimated GPU memory usage in GB
        """
        if name in self._models:
            logger.warning(f"⚠️ Model '{name}' already registered, updating")

        self._models[name] = ModelInfo(
            name=name,
            loader=loader,
            is_critical=is_critical,
            estimated_memory_gb=estimated_memory_gb,
            is_loaded=False,
            last_used=0.0,
            model_instance=None
        )
        logger.info(f"📝 Registered model: {name} (critical={is_critical}, ~{estimated_memory_gb:.1f}GB)")

    def load_model(self, name: str) -> Any:
        """
        Load a model by name. Returns cached instance if already loaded.

        Args:
            name: Model identifier

        Returns:
            Loaded model instance

        Raises:
            KeyError: If model not registered
            RuntimeError: If loading fails
        """
        if name not in self._models:
            raise KeyError(f"Model '{name}' not registered")

        model_info = self._models[name]

        # Return cached instance
        if model_info.is_loaded and model_info.model_instance is not None:
            model_info.last_used = time.time()
            logger.debug(f"📦 Using cached model: {name}")
            return model_info.model_instance

        # Check memory pressure before loading
        self.check_memory_pressure()

        # Load the model
        try:
            logger.info(f"📥 Loading model: {name}")
            start_time = time.time()

            model_instance = model_info.loader()

            model_info.model_instance = model_instance
            model_info.is_loaded = True
            model_info.last_used = time.time()

            load_time = time.time() - start_time
            logger.info(f"✅ Model '{name}' loaded in {load_time:.1f}s")
            return model_instance

        except Exception as e:
            logger.error(f"❌ Failed to load model '{name}': {e}")
            raise RuntimeError(f"Model loading failed: {e}")

    def unload_model(self, name: str):
        """
        Unload a specific model to free memory.
        Args:
            name: Model identifier
        """
        if name not in self._models:
            return

        model_info = self._models[name]
        if not model_info.is_loaded:
            return

        try:
            logger.info(f"🗑️ Unloading model: {name}")

            # Delete model instance
            if model_info.model_instance is not None:
                del model_info.model_instance
                model_info.model_instance = None

            model_info.is_loaded = False

            # Cleanup
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ Model '{name}' unloaded")

        except Exception as e:
            logger.error(f"❌ Error unloading model '{name}': {e}")

    def check_memory_pressure(self) -> bool:
        """
        Check GPU memory usage and unload least-used non-critical models if needed.

        Returns:
            True if cleanup was performed
        """
        if not torch.cuda.is_available():
            return False

        allocated = torch.cuda.memory_allocated() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        usage_ratio = allocated / total

        if usage_ratio < self._memory_threshold:
            return False

        logger.warning(f"⚠️ Memory pressure detected: {usage_ratio:.1%} used")

        # Find non-critical models sorted by last used time
        unloadable = [
            (name, info) for name, info in self._models.items()
            if info.is_loaded and not info.is_critical
        ]
        unloadable.sort(key=lambda x: x[1].last_used)

        # Unload oldest non-critical models
        cleaned = False
        for name, info in unloadable:
            self.unload_model(name)
            cleaned = True

            # Re-check memory
            new_ratio = torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory
            if new_ratio < self._memory_threshold * 0.7:  # Target 70% of threshold
                break

        return cleaned

    def force_cleanup(self):
        """Force cleanup all non-critical models and clear caches."""
        logger.info("🧹 Force cleanup initiated")

        # Unload all non-critical models
        for name, info in self._models.items():
            if info.is_loaded and not info.is_critical:
                self.unload_model(name)

        # Aggressive garbage collection
        for _ in range(5):
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

        logger.info("✅ Force cleanup completed")

    def get_memory_status(self) -> Dict[str, Any]:
        """
        Get detailed memory status.

        Returns:
            Dictionary with memory statistics
        """
        status = {
            "device": self._device,
            "models": {},
            "total_estimated_gb": 0.0
        }

        # Model status
        for name, info in self._models.items():
            status["models"][name] = {
                "loaded": info.is_loaded,
                "critical": info.is_critical,
                "estimated_gb": info.estimated_memory_gb,
                "last_used": info.last_used
            }
            if info.is_loaded:
                status["total_estimated_gb"] += info.estimated_memory_gb

        # GPU memory
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3

            status["gpu"] = {
                "allocated_gb": round(allocated, 2),
                "total_gb": round(total, 2),
                "cached_gb": round(cached, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1)
            }

        return status

    def get_loaded_models(self) -> list:
        """Get list of currently loaded model names."""
        return [name for name, info in self._models.items() if info.is_loaded]

    def is_model_loaded(self, name: str) -> bool:
        """Check if a specific model is loaded."""
        if name not in self._models:
            return False
        return self._models[name].is_loaded


# Global singleton instance
_model_manager = None


def get_model_manager() -> ModelManager:
    """Get the global ModelManager singleton instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
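

# Minimal usage sketch of the manager's registration / lazy-loading flow.
# The model name "tiny-linear", the dummy loader, and the memory estimate below
# are illustrative assumptions for demonstration only, not part of the module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = get_model_manager()

    # Any zero-argument callable that returns the model works as a loader.
    def _load_tiny_linear():
        return torch.nn.Linear(16, 16)

    manager.register_model(
        name="tiny-linear",
        loader=_load_tiny_linear,
        is_critical=False,
        estimated_memory_gb=0.001,
    )

    model = manager.load_model("tiny-linear")  # first call invokes the loader
    model = manager.load_model("tiny-linear")  # second call returns the cached instance
    print(manager.get_loaded_models())         # ['tiny-linear']
    print(manager.get_memory_status())

    manager.unload_model("tiny-linear")
    print(manager.is_model_loaded("tiny-linear"))  # False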