Spaces:
Build error
Build error
| from __future__ import annotations | |
| import datetime | |
| import os | |
| import pathlib | |
| import shlex | |
| import shutil | |
| import subprocess | |
| import sys | |
| import gradio as gr | |
| import slugify | |
| import torch | |
| import huggingface_hub | |
| from huggingface_hub import HfApi | |
| from omegaconf import OmegaConf | |
# Hub id of the upstream demo Space this code was published from.
ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero'
# Id of the Space we are running in; outside a Space (no SPACE_ID env var)
# this falls back to the upstream id above.
SPACE_ID = os.environ.get('SPACE_ID', ORIGINAL_SPACE_ID)
class Runner:
    """Prepare base models and launch vid2vid-zero inference for the Gradio demo."""

    def __init__(self, hf_token: str | None = None):
        """Create the runner and make sure ``checkpoints/`` exists.

        Args:
            hf_token: Token preferred over a per-request token when the Space
                hardware is downgraded after a run.
        """
        self.hf_token = hf_token
        self.checkpoint_dir = pathlib.Path('checkpoints')
        self.checkpoint_dir.mkdir(exist_ok=True)

    def download_base_model(self, base_model_id: str,
                            token: str | None = None) -> str:
        """Make ``base_model_id`` available locally and return its path.

        Without a token the repository is cloned with ``git lfs`` into
        ``checkpoints/<base_model_id>``; the clone is skipped when that
        directory already exists. With a token the model is fetched through
        ``huggingface_hub.snapshot_download`` and the path it returns is used.

        Args:
            base_model_id: Hub repo id, e.g. ``"org/name"`` (org prefix optional).
            token: Optional Hugging Face token for private/gated repos.

        Returns:
            Local directory of the model as a POSIX path string.
        """
        model_dir = self.checkpoint_dir / base_model_id
        if token is None:
            if not model_dir.exists():
                # Clone into the parent so git creates ``model_dir`` itself.
                # Using the parent (instead of the first id component) also
                # handles ids without an organisation prefix.
                parent_dir = model_dir.parent
                parent_dir.mkdir(parents=True, exist_ok=True)
                print(f'https://huggingface.co/{base_model_id}')
                # argv lists: ids are never re-parsed by a shell/shlex.
                subprocess.run(['git', 'lfs', 'install'], cwd=parent_dir)
                subprocess.run(
                    ['git', 'lfs', 'clone',
                     f'https://huggingface.co/{base_model_id}'],
                    cwd=parent_dir)
            return model_dir.as_posix()

        print(f'https://huggingface.co/{base_model_id}')
        # NOTE(review): `use_auth_token` is deprecated in newer huggingface_hub
        # releases in favour of `token`; kept for compatibility with the
        # version this Space pins.
        return huggingface_hub.snapshot_download(base_model_id,
                                                 use_auth_token=token)

    def join_model_library_org(self, token: str,
                               url: str | None = None) -> None:
        """POST to the "join model library organization" endpoint with ``token``.

        Args:
            token: Hugging Face token, sent as a Bearer credential.
            url: Endpoint to call. Defaults to a module-level
                ``URL_TO_JOIN_MODEL_LIBRARY_ORG`` constant when one is defined.

        Raises:
            ValueError: if no URL is given and no such constant is defined.
        """
        if url is None:
            url = globals().get('URL_TO_JOIN_MODEL_LIBRARY_ORG')
        if not url:
            raise ValueError(
                'No URL given and URL_TO_JOIN_MODEL_LIBRARY_ORG is not defined.')
        # Explicit argv: quotes/spaces in the token cannot break the command.
        subprocess.run([
            'curl', '-X', 'POST',
            '-H', f'Authorization: Bearer {token}',
            '-H', 'Content-Type: application/json',
            url,
        ])

    @staticmethod
    def _prompt_to_dir_name(prompt: str, max_length: int = 60) -> str:
        """Turn a free-form prompt into one safe path component.

        Runs of characters other than word characters and ``-`` (including
        ``/``, ``\\`` and ``.``) collapse to ``_``, so the result can never
        escape its parent directory. The result is capped at ``max_length``
        characters to stay below common filename limits even for multi-byte
        UTF-8 text. Falls back to ``"output"`` if nothing usable remains.
        """
        import re  # local import: the module header does not import `re`
        name = re.sub(r'[^\w\-]+', '_', prompt).strip('_')[:max_length]
        return name or 'output'

    def _release_gpu(self, input_token: str | None) -> None:
        """Ask the Hub to switch this Space to ``cpu-basic`` hardware.

        Does nothing outside a Space (no ``SPACE_ID`` environment variable).
        The instance token is preferred over ``input_token``.
        """
        # Read the env var directly rather than the module-level default, so
        # a local run never downgrades the upstream Space.
        space_id = os.getenv('SPACE_ID')
        if not space_id:
            return
        api = HfApi(token=self.hf_token or input_token)
        api.request_space_hardware(repo_id=space_id, hardware='cpu-basic')

    def run_vid2vid_zero(
        self,
        model_path: str,
        input_video: str,
        prompt: str,
        n_sample_frames: int,
        sample_start_idx: int,
        sample_frame_rate: int,
        validation_prompt: str,
        guidance_scale: float,
        resolution: str,
        seed: int,
        remove_gpu_after_running: bool,
        input_token: str | None = None,
    ) -> str:
        """Run vid2vid-zero on ``input_video`` and return the result video path.

        The run is configured from ``configs/black-swan.yaml`` (resolved
        against the current working directory) with null-inversion disabled
        and fp16 enabled for a faster web demo. The merged config is saved as
        ``outputs/<sanitised prompt>/config.yaml`` next to this file, and
        ``test_vid2vid_zero.py`` is launched through ``accelerate``.

        Args:
            model_path: Hub id of the base model (see ``download_base_model``).
            input_video: Uploaded file object whose ``.name`` is used as the
                video path (presumably Gradio's temp-file wrapper, despite the
                ``str`` annotation).
            prompt: Written to ``input_data.prompt``; also names the output
                directory after sanitisation.
            n_sample_frames: Cast to ``int``; written to
                ``input_data.n_sample_frames``.
            sample_start_idx: Cast to ``int``; written to
                ``input_data.sample_start_idx``.
            sample_frame_rate: Cast to ``int``; written to
                ``input_data.sample_frame_rate``.
            validation_prompt: The single entry of ``validation_data.prompts``.
            guidance_scale: Written to ``validation_data.guidance_scale``.
            resolution: Cast to ``int``; used as both width and height.
            seed: Written to ``config.seed``.
            remove_gpu_after_running: If true and running inside a Space,
                request ``cpu-basic`` hardware once inference has finished.
            input_token: Optional token for gated models and the hardware
                request.

        Returns:
            Path of ``sample-all.mp4`` inside the output directory.

        Raises:
            gr.Error: if CUDA is unavailable, a required input is missing, or
                the inference process fails or produces no video.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if input_video is None:
            raise gr.Error('You need to upload a video.')
        if not prompt:
            raise gr.Error('The input prompt is missing.')
        if not validation_prompt:
            raise gr.Error('The validation prompt is missing.')

        resolution = int(resolution)
        n_sample_frames = int(n_sample_frames)
        sample_start_idx = int(sample_start_idx)
        sample_frame_rate = int(sample_frame_rate)

        repo_dir = pathlib.Path(__file__).parent
        # The prompt is untrusted web input: sanitise it before using it as
        # a directory name so it cannot escape ``outputs/``.
        output_dir = repo_dir / 'outputs' / self._prompt_to_dir_name(prompt)
        output_dir.mkdir(parents=True, exist_ok=True)

        config = OmegaConf.load('configs/black-swan.yaml')
        config.pretrained_model_path = self.download_base_model(
            model_path, token=input_token)
        # we remove null-inversion & use fp16 for fast inference on web demo
        config.mixed_precision = "fp16"
        config.validation_data.use_null_inv = False
        config.output_dir = output_dir.as_posix()

        config.input_data.video_path = input_video.name  # type: ignore
        config.input_data.prompt = prompt
        config.input_data.n_sample_frames = n_sample_frames
        config.input_data.width = resolution
        config.input_data.height = resolution
        config.input_data.sample_start_idx = sample_start_idx
        config.input_data.sample_frame_rate = sample_frame_rate

        config.validation_data.prompts = [validation_prompt]
        # NOTE(review): video_length is fixed at 8 while n_sample_frames is
        # user-controlled; confirm the pipeline tolerates a mismatch.
        config.validation_data.video_length = 8
        config.validation_data.width = resolution
        config.validation_data.height = resolution
        config.validation_data.num_inference_steps = 50
        config.validation_data.guidance_scale = guidance_scale

        config.input_batch_size = 1
        config.seed = seed

        config_path = output_dir / 'config.yaml'
        # UTF-8 explicitly: prompts may contain non-ASCII text.
        with open(config_path, 'w', encoding='utf-8') as f:
            OmegaConf.save(config, f)

        # argv list: a repo path containing spaces cannot split the command.
        result = subprocess.run([
            'accelerate', 'launch', 'test_vid2vid_zero.py',
            '--config', str(config_path),
        ])
        output_video_path = os.path.join(output_dir, "sample-all.mp4")
        print(f"video path for gradio: {output_video_path}")

        # A non-zero exit must not hand back a stale video from an earlier
        # run with the same prompt.
        succeeded = (result.returncode == 0
                     and os.path.isfile(output_video_path))
        if succeeded:
            print('Running completed!')

        # Release the GPU even on failure, as the caller asked for it.
        if remove_gpu_after_running:
            self._release_gpu(input_token)

        if not succeeded:
            raise gr.Error(
                'Video editing failed; check the Space logs for details.')
        return output_video_path