#!/usr/bin/env python3
"""
OmniAvatar-14B Setup Script
Downloads all required models and sets up the proper directory structure.

Usage:
    python setup_omniavatar.py          # asks for confirmation before downloading
    python setup_omniavatar.py --yes    # non-interactive (CI / containers)
"""

import argparse
import logging
import os
import shutil
import subprocess
import sys
from pathlib import Path

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class OmniAvatarSetup:
    """Download the OmniAvatar-14B model weights and prepare the project layout.

    All paths are resolved relative to ``base_dir``. Model weights are placed
    under ``<base_dir>/pretrained_models/<model name>``.
    """

    def __init__(self, base_dir=None):
        """Initialise paths and the table of models to download.

        Args:
            base_dir: Root directory for the setup (``str`` or ``Path``).
                Defaults to the current working directory.
        """
        self.base_dir = Path(base_dir) if base_dir is not None else Path.cwd()
        self.models_dir = self.base_dir / "pretrained_models"

        # Model specifications from OmniAvatar documentation.
        # Each key is also the local sub-directory name under models_dir.
        self.models = {
            "Wan2.1-T2V-14B": {
                "repo": "Wan-AI/Wan2.1-T2V-14B",
                "description": "Base model for 14B OmniAvatar model",
                "size": "~28GB",
            },
            "OmniAvatar-14B": {
                "repo": "OmniAvatar/OmniAvatar-14B",
                "description": "LoRA and audio condition weights",
                "size": "~2GB",
            },
            "wav2vec2-base-960h": {
                "repo": "facebook/wav2vec2-base-960h",
                "description": "Audio encoder",
                "size": "~360MB",
            },
        }

    def check_dependencies(self) -> bool:
        """Check if required dependencies are installed.

        Returns:
            True if PyTorch can be imported, False otherwise. A missing CUDA
            runtime only produces a warning.
        """
        logger.info("🔍 Checking dependencies...")
        try:
            import torch
        except ImportError:
            logger.error("❌ PyTorch not installed!")
            return False

        logger.info(f"✅ PyTorch version: {torch.__version__}")
        if torch.cuda.is_available():
            logger.info(f"✅ CUDA available: {torch.version.cuda}")
            logger.info(f"✅ GPU devices: {torch.cuda.device_count()}")
        else:
            logger.warning("⚠️ CUDA not available - will use CPU (slower)")
        return True

    def install_huggingface_cli(self) -> bool:
        """Ensure the ``huggingface-cli`` executable is available, installing it if needed.

        Returns:
            True if ``huggingface-cli`` is on PATH (already, or after a
            ``pip install huggingface_hub[cli]``), False otherwise.
        """
        # Look the executable up directly instead of relying on a CLI flag to probe it.
        if shutil.which('huggingface-cli'):
            logger.info("✅ Hugging Face CLI available")
            return True

        logger.info("📦 Installing huggingface-hub CLI...")
        try:
            subprocess.run(
                [sys.executable, '-m', 'pip', 'install', 'huggingface_hub[cli]'],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Failed to install Hugging Face CLI: {e}")
            return False

        # pip may put console scripts in a directory that is not on PATH
        # (e.g. a --user install). The download step would then fail to find it.
        if shutil.which('huggingface-cli') is None:
            logger.error("❌ huggingface_hub was installed but 'huggingface-cli' is not on PATH")
            return False

        logger.info("✅ Hugging Face CLI installed")
        return True

    def create_directory_structure(self) -> None:
        """Create the models directory, one sub-directory per model, and the working dirs."""
        logger.info("📁 Creating directory structure...")
        directories = [self.models_dir]
        # Derived from self.models so adding a model also creates its directory.
        directories += [self.models_dir / model_name for model_name in self.models]
        directories += [
            self.base_dir / "outputs",
            self.base_dir / "configs",
            self.base_dir / "scripts",
            self.base_dir / "examples",
        ]
        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)
            logger.info(f"✅ Created: {directory}")

    @staticmethod
    def _is_affirmative(response: str) -> bool:
        """Return True if *response* means yes ('y' or 'yes', ignoring case and surrounding whitespace)."""
        return response.strip().lower() in ('y', 'yes')

    def _confirm_download(self, auto_confirm: bool) -> bool:
        """Ask the user to confirm the large download. With *auto_confirm*, skip the prompt."""
        if auto_confirm:
            return True
        try:
            response = input("Continue with download? (y/N): ")
        except EOFError:
            # stdin is closed (non-interactive run): treat as "no" instead of crashing.
            logger.warning("⚠️ No interactive input available; re-run with --yes to skip the prompt")
            return False
        return self._is_affirmative(response)

    def _download_model(self, model_name: str, model_info: dict) -> bool:
        """Download one model repo into ``models_dir/<model_name>``.

        Returns:
            True if the model was downloaded or was already present, False on failure.
        """
        logger.info(f"📥 Downloading {model_name} ({model_info['size']})...")
        logger.info(f"📝 {model_info['description']}")

        local_dir = self.models_dir / model_name

        # Skip if the directory already exists and has content.
        # NOTE(review): a download interrupted part-way also leaves files here
        # and would be skipped on re-run; delete the directory to force a retry.
        if local_dir.exists() and any(local_dir.iterdir()):
            logger.info(f"✅ {model_name} already exists, skipping...")
            return True

        cmd = [
            'huggingface-cli', 'download',
            model_info['repo'],
            '--local-dir', str(local_dir),
        ]
        logger.info(f"🚀 Running: {' '.join(cmd)}")
        try:
            subprocess.run(cmd, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing/non-executable huggingface-cli binary.
            logger.error(f"❌ Failed to download {model_name}: {e}")
            return False

        logger.info(f"✅ {model_name} downloaded successfully!")
        return True

    def download_models(self, auto_confirm: bool = False) -> bool:
        """Download all required models.

        Args:
            auto_confirm: Skip the interactive confirmation prompt.

        Returns:
            True if every model is present afterwards. False if the user
            declined or any download failed (remaining models are not attempted).
        """
        logger.info("🔄 Starting model downloads...")
        logger.info("⚠️ This will download approximately 30GB of models!")

        if not self._confirm_download(auto_confirm):
            logger.info("❌ Download cancelled by user")
            return False

        for model_name, model_info in self.models.items():
            if not self._download_model(model_name, model_info):
                return False

        logger.info("✅ All models downloaded successfully!")
        return True

    def run_setup(self, auto_confirm: bool = False) -> bool:
        """Run the complete setup process.

        Args:
            auto_confirm: Forwarded to ``download_models`` to skip the prompt.

        Returns:
            True on success, False as soon as any step fails.
        """
        logger.info("🚀 Starting OmniAvatar-14B setup...")

        if not self.check_dependencies():
            logger.error("❌ Dependencies check failed!")
            return False

        if not self.install_huggingface_cli():
            logger.error("❌ Failed to install Hugging Face CLI!")
            return False

        self.create_directory_structure()

        if not self.download_models(auto_confirm=auto_confirm):
            logger.error("❌ Model download failed!")
            return False

        logger.info("🎉 OmniAvatar-14B setup completed successfully!")
        logger.info("💡 You can now run the full avatar generation!")
        return True


def main(argv=None) -> int:
    """Parse command-line options, run the setup, and return a process exit code.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv[1:]``.

    Returns:
        0 on success, 1 on failure, 130 if interrupted with Ctrl+C.
    """
    parser = argparse.ArgumentParser(description="Download and set up OmniAvatar-14B models.")
    parser.add_argument(
        '-y', '--yes', action='store_true',
        help="download without asking for confirmation",
    )
    args = parser.parse_args(argv)

    setup = OmniAvatarSetup()
    try:
        success = setup.run_setup(auto_confirm=args.yes)
    except KeyboardInterrupt:
        logger.error("❌ Setup interrupted by user")
        return 130
    # Report failure to the caller (shell, CI, Dockerfile) instead of always exiting 0.
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())