AI_Avatar_Chat / auto_model_download.py
Developer
🎬 WORKING VIDEO GENERATION: Actually download and use models!
24574f4
raw
history blame
2.54 kB
#!/usr/bin/env python3
"""
Auto Model Download for HF Spaces
Automatically downloads video generation models on startup if storage allows
"""
import os
import logging
import asyncio
from pathlib import Path
logger = logging.getLogger(__name__)


def _nearest_existing_dir(path: Path) -> Path:
    """Return *path* itself or its closest existing ancestor (used to probe disk space)."""
    probe = path
    # Stop once the parent no longer changes ("." or the filesystem root) so this cannot loop forever.
    while not probe.exists() and probe != probe.parent:
        probe = probe.parent
    return probe


async def auto_download_models(
    *,
    models_dir: "str | os.PathLike[str]" = "./downloaded_models",
    min_free_gb: float = 8.0,
) -> bool:
    """Download the text-to-video and audio models on startup if storage allows.

    The download is skipped only when the ``download_success.txt`` marker written
    by a previous *completed* run exists in *models_dir*. A directory that merely
    contains files (e.g. left behind by an interrupted or failed attempt) is not
    treated as complete, so the download is retried.

    The blocking ``snapshot_download`` calls run in the event loop's default
    executor so other coroutines keep running while the models download.

    Args:
        models_dir: Directory that receives the ``video``/``audio`` caches and
            the success marker. Defaults to ``./downloaded_models``.
        min_free_gb: Minimum free disk space (GiB, i.e. 1024**3 bytes) required
            before a download is attempted. Defaults to 8.

    Returns:
        True if the models are available (already present or just downloaded),
        False if storage was insufficient or any step failed. Errors are logged
        with traceback rather than raised, so this is safe as a startup hook.
    """
    import functools
    import shutil

    logger.info("?? Auto model download starting...")
    try:
        root = Path(models_dir)
        success_file = root / "download_success.txt"

        # Only a completed run writes the marker; a non-empty directory alone may
        # be a partial download from an earlier attempt that failed midway.
        if success_file.is_file():
            logger.info("? Models already downloaded, skipping...")
            return True

        # Measure free space on the filesystem that will actually hold the models.
        _, _, free_bytes = shutil.disk_usage(_nearest_existing_dir(root))
        free_gb = free_bytes / (1024**3)
        logger.info(f"?? Storage available: {free_gb:.1f}GB")
        if free_gb < min_free_gb:
            logger.warning(
                f"?? Insufficient storage for model download: "
                f"{free_gb:.1f}GB < {min_free_gb:g}GB"
            )
            return False

        # Imported lazily: the hub client is only needed when a download happens.
        from huggingface_hub import snapshot_download

        logger.info("?? Starting automatic model download...")
        root.mkdir(parents=True, exist_ok=True)
        loop = asyncio.get_running_loop()

        logger.info("?? Downloading text-to-video model...")
        video_path = await loop.run_in_executor(
            None,
            functools.partial(
                snapshot_download,
                repo_id="ali-vilab/text-to-video-ms-1.7b",
                cache_dir=str(root / "video"),
            ),
        )

        logger.info("?? Downloading audio model...")
        audio_path = await loop.run_in_executor(
            None,
            functools.partial(
                snapshot_download,
                repo_id="facebook/wav2vec2-base-960h",
                cache_dir=str(root / "audio"),
            ),
        )

        # Written last, only after both downloads succeeded: its presence is the
        # "complete" signal checked at the top of this function.
        success_file.write_text(
            "Models downloaded successfully\n"
            f"Video model: {video_path}\n"
            f"Audio model: {audio_path}\n",
            encoding="utf-8",
        )
        logger.info("? Auto model download completed successfully!")
        return True
    except Exception as e:
        # Best-effort startup hook: never crash the app, but keep the traceback.
        logger.exception(f"? Auto model download failed: {e}")
        return False
# Set at import time, before (and regardless of) any actual download attempt.
# NOTE(review): the name says "attempted" but it only signals that this module was
# imported; consumers of the flag elsewhere should be checked against that meaning.
os.environ["AUTO_MODEL_DOWNLOAD_ATTEMPTED"] = "1"

if __name__ == "__main__":
    # No handler is configured when run as a script, so without this the
    # INFO-level progress messages from auto_download_models would be dropped.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    # Run the auto download; failure is non-fatal (the app falls back to TTS mode).
    result = asyncio.run(auto_download_models())
    if result:
        print("?? Models ready for video generation!")
    else:
        print("?? Models not downloaded - running in TTS mode")