Commit 24363dc · committed by CSH-1220
Parent: 1834911

Update how we load pre-trained weights
Files changed:
- app.py +0 -15
- audio_encoder/AudioMAE.py +6 -1
- pipeline/morph_pipeline_successed_ver1.py +7 -3
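The net effect: the two start-up hf_hub_download calls in app.py, which copied pretrained.pth and pytorch_model.bin into the Space's working directory via local_dir="./", are removed, and each weight file is now downloaded at the point where it is consumed (the AudioMAE encoder fetches pretrained.pth, the morph pipeline fetches pytorch_model.bin). Without local_dir, hf_hub_download resolves to a file inside the Hugging Face hub cache rather than the repository root; a minimal sketch of the new call pattern (repo id and filename taken from the diff):

from huggingface_hub import hf_hub_download

# Returns a path inside the hub cache (HF_HOME); the file is downloaded only on first call.
path = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pretrained.pth",
)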
app.py
CHANGED
@@ -3,21 +3,6 @@ import torch
 import torchaudio
 import numpy as np
 import gradio as gr
-from huggingface_hub import hf_hub_download
-model_path = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pretrained.pth",
-    local_dir="./",
-    local_dir_use_symlinks=False
-)
-
-model_path = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pytorch_model.bin",
-    local_dir="./",
-    local_dir_use_symlinks=False
-)
-
 from pipeline.morph_pipeline_successed_ver1 import AudioLDM2MorphPipeline
 # Initialize AudioLDM2 Pipeline
 pipeline = AudioLDM2MorphPipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
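With the downloads gone, app.py's start-up reduces to importing the pipeline module and constructing the pipeline. A hedged sketch of the remaining initialization (the device move is an assumption and does not appear in the file):

import torch
from pipeline.morph_pipeline_successed_ver1 import AudioLDM2MorphPipeline

# Initialize AudioLDM2 Pipeline, as in app.py above.
pipeline = AudioLDM2MorphPipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")  # assumed, not in the diff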
audio_encoder/AudioMAE.py
CHANGED
@@ -12,6 +12,7 @@ import librosa.display
 import matplotlib.pyplot as plt
 import numpy as np
 import torchaudio
+from huggingface_hub import hf_hub_download
 
 # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
 class Vanilla_AudioMAE(nn.Module):
@@ -25,7 +26,11 @@ class Vanilla_AudioMAE(nn.Module):
             in_chans=1, audio_exp=True, img_size=(1024, 128)
         )
 
-        checkpoint_path = 'pretrained.pth'
+        # checkpoint_path = 'pretrained.pth'
+        checkpoint_path = hf_hub_download(
+            repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+            filename="pretrained.pth"
+        )
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         msg = model.load_state_dict(checkpoint['model'], strict=False)
 
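Inside Vanilla_AudioMAE the checkpoint is therefore no longer expected at a hard-coded relative path; hf_hub_download fetches and caches it, and torch.load consumes the returned path. Because the state dict is applied with strict=False, keys the encoder module does not define are tolerated rather than raising. A small sketch for inspecting what actually loads (repo id and filename from the diff; the 'model' entry is the key the code above indexes):

from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pretrained.pth",
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# The module loads checkpoint['model'] with strict=False, so load_state_dict
# reports missing/unexpected keys in its return value instead of failing.
print(list(checkpoint.keys()))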
pipeline/morph_pipeline_successed_ver1.py
CHANGED
@@ -49,8 +49,7 @@ if is_librosa_available():
     import librosa
 import warnings
 import matplotlib.pyplot as plt
-
-
+from huggingface_hub import hf_hub_download
 from .pipeline_audioldm2 import AudioLDM2Pipeline
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -91,7 +90,12 @@ for name in unet.attn_processors.keys():
     else:
         attn_procs[name] = AttnProcessor2_0()
 
-
+adapter_weight = hf_hub_download(
+    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+    filename="pytorch_model.bin",
+)
+
+state_dict = torch.load(adapter_weight, map_location=DEVICE)
 for name, processor in attn_procs.items():
     if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
         weight_name_v = name + ".to_v_ip.weight"
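The adapter weights follow the same pattern: pytorch_model.bin is fetched from the hub, loaded onto DEVICE, and the loop then assigns per-processor tensors by key name. Only the to_v_ip key construction is visible in the diff; the rest of this sketch is an assumed continuation of that loop:

# Hedged sketch: distribute the downloaded state_dict across the IP attention processors.
for name, processor in attn_procs.items():
    if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
        weight_name_v = name + ".to_v_ip.weight"
        weight_name_k = name + ".to_k_ip.weight"  # assumed counterpart key
        if weight_name_v in state_dict:
            processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].to(DEVICE))
        if weight_name_k in state_dict:
            processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].to(DEVICE))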