#!/usr/bin/env python3
"""
Auto Model Download for HF Spaces

Automatically downloads video generation models on startup if storage allows.
"""

import asyncio
import functools
import logging
import os
import shutil
from pathlib import Path

logger = logging.getLogger(__name__)

# File written inside the models directory only after every model download
# has finished. Its presence -- not merely a non-empty directory -- marks the
# download as complete, so an interrupted download is retried on next start.
SUCCESS_MARKER = "download_success.txt"


async def _snapshot_download_in_thread(repo_id, cache_dir):
    """Run ``huggingface_hub.snapshot_download`` in the default executor.

    ``snapshot_download`` performs blocking network and disk I/O; running it
    in a worker thread keeps the event loop responsive during large downloads.

    Args:
        repo_id: Hugging Face Hub repository id to download.
        cache_dir: Directory passed to ``snapshot_download`` as its cache.

    Returns:
        Whatever ``snapshot_download`` returns (the local snapshot path).

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
        Exception: anything ``snapshot_download`` raises is propagated.
    """
    # Imported lazily so the module (and the skip paths) work without the package.
    from huggingface_hub import snapshot_download

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        functools.partial(snapshot_download, repo_id=repo_id, cache_dir=str(cache_dir)),
    )


async def auto_download_models(*, models_dir="./downloaded_models", min_free_gb=8.0):
    """Auto-download models on startup if possible.

    Downloads a text-to-video model and an audio model into sub-directories of
    *models_dir*, then writes a success marker file. Failures are logged and
    reported via the return value rather than raised (best-effort startup).

    Args:
        models_dir: Directory that holds the downloaded models and the marker.
        min_free_gb: Minimum free space (GiB) on the models filesystem
            required before attempting a download.

    Returns:
        True if the models were already downloaded or were downloaded now;
        False if storage is insufficient or any step failed.
    """
    logger.info("?? Auto model download starting...")

    try:
        models_dir = Path(models_dir)
        success_file = models_dir / SUCCESS_MARKER

        # Only a completed run writes the marker; a partially populated
        # directory from an interrupted download must not count as done.
        if success_file.exists():
            logger.info("? Models already downloaded, skipping...")
            return True

        # Measure free space on the filesystem the models will actually use.
        models_dir.mkdir(parents=True, exist_ok=True)
        _, _, free_bytes = shutil.disk_usage(models_dir)
        free_gb = free_bytes / (1024**3)
        logger.info("?? Storage available: %.1fGB", free_gb)

        if free_gb < min_free_gb:
            logger.warning(
                "?? Insufficient storage for model download: %.1fGB < %gGB",
                free_gb,
                min_free_gb,
            )
            return False

        logger.info("?? Starting automatic model download...")

        logger.info("?? Downloading text-to-video model...")
        video_path = await _snapshot_download_in_thread(
            "ali-vilab/text-to-video-ms-1.7b", models_dir / "video"
        )

        logger.info("?? Downloading audio model...")
        audio_path = await _snapshot_download_in_thread(
            "facebook/wav2vec2-base-960h", models_dir / "audio"
        )

        # Written last, so it exists only after both downloads succeeded.
        success_file.write_text(
            "Models downloaded successfully\n"
            f"Video model: {video_path}\n"
            f"Audio model: {audio_path}\n",
            encoding="utf-8",
        )

        logger.info("? Auto model download completed successfully!")
        return True

    except Exception as e:
        # Best-effort by design: a missing package, network or disk error
        # leaves the app in its fallback mode instead of crashing startup.
        logger.exception("? Auto model download failed: %s", e)
        return False


# Set environment variable to indicate auto-download attempt.
# NOTE: this runs at import time, i.e. whenever this module is loaded, not
# only when auto_download_models() is actually awaited.
os.environ["AUTO_MODEL_DOWNLOAD_ATTEMPTED"] = "1"


if __name__ == "__main__":
    # Without a configured handler, INFO-level progress messages are dropped.
    logging.basicConfig(level=logging.INFO)

    # Run auto download
    result = asyncio.run(auto_download_models())
    if result:
        print("?? Models ready for video generation!")
    else:
        print("?? Models not downloaded - running in TTS mode")