CSH-1220 committed · Commit 55f08a9 · 1 Parent(s): aef267d

Files update

Changed files:
- app.py +5 -1
- download.py +9 -0
- pipeline/morph_pipeline_successed_ver1.py +101 -175
- utils/lora_utils_successed_ver1.py +12 -24
app.py
CHANGED

@@ -47,9 +47,13 @@ def morph_audio(audio_file1, audio_file2, prompt1, prompt2, negative_prompt1="Lo
     )
 
     # Collect the output file paths
-    output_paths =
+    output_paths = sorted(
+        [os.path.join(save_lora_dir, file) for file in os.listdir(save_lora_dir) if file.endswith(".wav")],
+        key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
+    )
     return output_paths
 
+
 # Gradio interface function
 def interface(audio1, audio2, prompt1, prompt2):
     output_paths = morph_audio(audio1, audio2, prompt1, prompt2)
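Note on the new output_paths expression: it assumes every generated clip in save_lora_dir is named by an integer frame index (0.wav, 1.wav, ..., as written by the pipeline's scipy.io.wavfile.write calls later in this commit), so sorting by the integer stem returns the frames in morphing order. A minimal sketch of the same pattern, with a guard against stray non-numeric file names (the guard is added here for illustration and is not part of the commit):

import os

def collect_frames(directory):
    # Keep only integer-named .wav files ("0.wav", "1.wav", ...) and sort
    # them numerically so "10.wav" follows "9.wav" instead of "1.wav".
    wavs = [f for f in os.listdir(directory)
            if f.endswith(".wav") and os.path.splitext(f)[0].isdigit()]
    return [os.path.join(directory, f)
            for f in sorted(wavs, key=lambda f: int(os.path.splitext(f)[0]))]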
download.py
ADDED

@@ -0,0 +1,9 @@
+from huggingface_hub import hf_hub_download
+import torch
+
+model_path = hf_hub_download(
+    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+    filename="pytorch_model.bin",
+    local_dir="./",
+    local_dir_use_symlinks=False
+)
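download.py pulls the AP-adapter checkpoint into the repository root, which is where the pipeline's hard-coded ap_adapter_path = 'pytorch_model.bin' (see the next file) expects to find it. A hedged sketch of verifying the download, assuming the file is a flat PyTorch state dict keyed by attention-processor parameter names:

import torch
from huggingface_hub import hf_hub_download

# Same call as download.py; local_dir places the file next to the code.
path = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pytorch_model.bin",
    local_dir="./",
)
state_dict = torch.load(path, map_location="cpu")
# Expect keys such as "<attn_processor_name>.to_k_ip.weight".
print(f"{len(state_dict)} tensors, e.g. {next(iter(state_dict))}")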
pipeline/morph_pipeline_successed_ver1.py
CHANGED

@@ -49,64 +49,12 @@ if is_librosa_available():
 import librosa
 import warnings
 import matplotlib.pyplot as plt
-from huggingface_hub import hf_hub_download
 from .pipeline_audioldm2 import AudioLDM2Pipeline
 
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-pipeline_trained = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
-pipeline_trained = pipeline_trained.to(DEVICE)
-layer_num = 0
-cross = [None, None, 768, 768, 1024, 1024, None, None]
-unet = pipeline_trained.unet
-
-
-attn_procs = {}
-for name in unet.attn_processors.keys():
-    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-    if name.startswith("mid_block"):
-        hidden_size = unet.config.block_out_channels[-1]
-    elif name.startswith("up_blocks"):
-        block_id = int(name[len("up_blocks.")])
-        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-    elif name.startswith("down_blocks"):
-        block_id = int(name[len("down_blocks.")])
-        hidden_size = unet.config.block_out_channels[block_id]
-
-    if cross_attention_dim is None:
-        attn_procs[name] = AttnProcessor2_0()
-    else:
-        cross_attention_dim = cross[layer_num % 8]
-        layer_num += 1
-        if cross_attention_dim == 768:
-            attn_procs[name] = IPAttnProcessor2_0(
-                hidden_size=hidden_size,
-                name=name,
-                cross_attention_dim=cross_attention_dim,
-                scale=0.5,
-                num_tokens=8,
-                do_copy=False
-            ).to(DEVICE, dtype=torch.float32)
-        else:
-            attn_procs[name] = AttnProcessor2_0()
-
-adapter_weight = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pytorch_model.bin",
-)
-
-state_dict = torch.load(adapter_weight, map_location=DEVICE)
-for name, processor in attn_procs.items():
-    if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
-        weight_name_v = name + ".to_v_ip.weight"
-        weight_name_k = name + ".to_k_ip.weight"
-        processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
-        processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
-
-unet.set_attn_processor(attn_procs)
-unet.to(DEVICE, dtype=torch.float32)
-
 
+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):

@@ -125,10 +73,6 @@ def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):
     plt.show()
 
 
-warnings.filterwarnings("ignore", category=FutureWarning)
-warnings.filterwarnings("ignore", category=UserWarning)
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
 class StoreProcessor():
     def __init__(self, original_processor, value_dict, name):
         self.original_processor = original_processor

@@ -140,12 +84,9 @@ class StoreProcessor():
     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
         # Is self attention
         if encoder_hidden_states is None:
-            # Store hidden_states in value_dict under the key self.name
-            # If there is no encoder_hidden_states, this is a self-attention layer, so keep the input hidden_states in value_dict.
             # print(f'In StoreProcessor: {self.name} {self.id}')
             self.value_dict[self.name][self.id] = hidden_states.detach()
             self.id += 1
-        # Call the original processor to run the normal attention computation
         res = self.original_processor(attn, hidden_states, *args,
                                       encoder_hidden_states=encoder_hidden_states,
                                       attention_mask=attention_mask,

@@ -167,32 +108,26 @@ class LoadProcessor():
 
     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
         # Is self attention
-        # Check whether this is self-attention
         if encoder_hidden_states is None:
-            # If the current index is below 10 * self.lamd, use the custom blending logic
             if self.id < 10 * self.lamd:
                 map0 = self.aud1_dict[self.name][self.id]
                 map1 = self.aud2_dict[self.name][self.id]
                 cross_map = self.beta * hidden_states + \
                     (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
-                # Call the original processor, passing cross_map as encoder_hidden_states
                 res = self.original_processor(attn, hidden_states, *args,
                                               encoder_hidden_states=cross_map,
                                               attention_mask=attention_mask,
                                               **kwargs)
             else:
-                # Otherwise use the original encoder_hidden_states (which may be None)
                 res = self.original_processor(attn, hidden_states, *args,
                                               encoder_hidden_states=encoder_hidden_states,
                                               attention_mask=attention_mask,
                                               **kwargs)
 
             self.id += 1
-            # When the index reaches the length of self.aud1_dict[self.name], reset it to 0
             if self.id == len(self.aud1_dict[self.name]):
                 self.id = 0
         else:
-            # For cross-attention (encoder_hidden_states is not None), use the original processor directly
             res = self.original_processor(attn, hidden_states, *args,
                                           encoder_hidden_states=encoder_hidden_states,
                                           attention_mask=attention_mask,

@@ -908,7 +843,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
         # print("ta_kaldi_fbank.shape",ta_kaldi_fbank.shape)
         mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
-        model = AudioMAEConditionCTPoolRand().
+        model = AudioMAEConditionCTPoolRand().cuda()
         model.eval()
         LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
         uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)

@@ -932,16 +867,66 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
 
         return prompt_embeds, attention_mask, generated_prompt_embeds
 
+    def init_trained_pipeline(self, model_path, device, dtype, ap_scale, text_ap_scale):
+        pipeline_trained = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=dtype).to(device)
+        layer_num = 0
+        cross = [None, None, 768, 768, 1024, 1024, None, None]
+        unet = pipeline_trained.unet
+        attn_procs = {}
+        for name in unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = unet.config.block_out_channels[block_id]
+
+            if cross_attention_dim is None:
+                attn_procs[name] = AttnProcessor2_0()
+            else:
+                cross_attention_dim = cross[layer_num % 8]
+                layer_num += 1
+                if cross_attention_dim == 768:
+                    attn_procs[name] = IPAttnProcessor2_0(
+                        hidden_size=hidden_size,
+                        name=name,
+                        flag='trained',
+                        cross_attention_dim=cross_attention_dim,
+                        text_scale=text_ap_scale,
+                        scale=ap_scale,
+                        num_tokens=8,
+                        do_copy=False
+                    ).to(device, dtype=dtype)
+                else:
+                    attn_procs[name] = AttnProcessor2_0()
+
+        state_dict = torch.load(model_path, map_location=device)
+        for name, processor in attn_procs.items():
+            if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
+                weight_name_v = name + ".to_v_ip.weight"
+                weight_name_k = name + ".to_k_ip.weight"
+                if dtype == torch.float32:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].float())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].float())
+                elif dtype == torch.float16:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
+        unet.set_attn_processor(attn_procs)
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return unet(*args, **kwargs)
+
+        unet = _Wrapper(unet.attn_processors)
+
+        return pipeline_trained
+
     @torch.no_grad()
     def aud2latent(self, audio_path, audio_length_in_s):
         DEVICE = torch.device(
             "cuda") if torch.cuda.is_available() else torch.device("cpu")
-
-        # waveform, sr = torchaudio.load(audio_path)
-        # fbank = torch.zeros((height, 64))
-        # ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank, num_mels=64)
-        # mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0).unsqueeze(0)
-
         mel_spect_tensor = wav_to_mel(audio_path, duration=audio_length_in_s).unsqueeze(0)
         output_path = audio_path.replace('.wav', '_fbank.png')
         visualize_mel_spectrogram(mel_spect_tensor, output_path)

@@ -954,7 +939,8 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
     @torch.no_grad()
     def ddim_inversion(self, start_latents, prompt_embeds, attention_mask, generated_prompt_embeds, guidance_scale,num_inference_steps):
         start_step = 0
-
+        # print(f"Scheduler timesteps: {self.scheduler.timesteps}")
+        num_inference_steps = min(num_inference_steps, int(max(self.scheduler.timesteps)))
         device = start_latents.device
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         start_latents *= self.scheduler.init_noise_sigma

@@ -973,9 +959,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
     def generate_morphing_prompt(self, prompt_1, prompt_2, alpha):
         closer_prompt = prompt_1 if alpha <= 0.5 else prompt_2
         prompt = (
-            f"
-            f"The sound is closer to '{closer_prompt}' with an interpolation factor of alpha={alpha:.2f}, "
-            f"where alpha=0 represents fully the {prompt_1} and alpha=1 represents fully {prompt_2}."
+            f"Jazz style music"
         )
         return prompt
 

@@ -983,8 +967,10 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
     def cal_latent(self,audio_length_in_s,time_pooling, freq_pooling,num_inference_steps, guidance_scale, aud_noise_1, aud_noise_2, prompt_1, prompt_2,
                    prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2,
                    alpha, original_processor,attn_processor_dict, use_morph_prompt, morphing_with_lora):
+        num_inference_steps = min(num_inference_steps, int(max(self.pipeline_trained.scheduler.timesteps)))
         latents = slerp(aud_noise_1, aud_noise_2, alpha, self.use_adain)
         if not use_morph_prompt:
+            print("Not using morphing prompt")
             max_length = max(prompt_embeds_1.shape[1], prompt_embeds_2.shape[1])
             if prompt_embeds_1.shape[1] < max_length:
                 pad_size = max_length - prompt_embeds_1.shape[1]

@@ -1033,13 +1019,13 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             # attention_mask = (attention_mask > 0.5).long()
 
             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(attn_processor_dict)
-            waveform = pipeline_trained(
+                self.pipeline_trained.unet.set_attn_processor(attn_processor_dict)
+            waveform = self.pipeline_trained(
                 time_pooling= time_pooling,
                 freq_pooling= freq_pooling,
                 latents = latents,
                 num_inference_steps= num_inference_steps,
-                guidance_scale= guidance_scale,
+                guidance_scale = guidance_scale,
                 num_waveforms_per_prompt= 1,
                 audio_length_in_s=audio_length_in_s,
                 prompt_embeds = prompt_embeds.chunk(2)[1],

@@ -1050,13 +1036,13 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 negative_attention_mask = attention_mask.chunk(2)[0],
             ).audios[0]
             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(original_processor)
+                self.pipeline_trained.unet.set_attn_processor(original_processor)
         else:
             latent_model_input = latents
             morphing_prompt = self.generate_morphing_prompt(prompt_1, prompt_2, alpha)
             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(attn_processor_dict)
-            waveform = pipeline_trained(
+                self.pipeline_trained.unet.set_attn_processor(attn_processor_dict)
+            waveform = self.pipeline_trained(
                 time_pooling= time_pooling,
                 freq_pooling= freq_pooling,
                 latents = latent_model_input,

@@ -1068,15 +1054,18 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 negative_prompt= 'Low quality',
             ).audios[0]
             if morphing_with_lora:
-                pipeline_trained.unet.set_attn_processor(original_processor)
+                self.pipeline_trained.unet.set_attn_processor(original_processor)
 
-        return waveform
+        return waveform, latents
 
     @torch.no_grad()
     def __call__(
         self,
+        dtype,
         audio_file = None,
         audio_file2 = None,
+        ap_scale = 1.0,
+        text_ap_scale = 1.0,
         save_lora_dir = "./lora",
         load_lora_path_1 = None,
        load_lora_path_2 = None,

@@ -1100,7 +1089,6 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         attn_beta=0,
         lamd=0.6,
         fix_lora=None,
-        save_intermediates=True,
         num_frames=50,
         max_new_tokens: Optional[int] = None,
         callback_steps: Optional[int] = 1,

@@ -1108,6 +1096,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         morphing_with_lora=False,
         use_morph_prompt=False,
     ):
+        ap_adapter_path = 'pytorch_model.bin'
         device = "cuda" if torch.cuda.is_available() else "cpu"
         # 0. Load the pre-trained AP-adapter model
         layer_num = 0

@@ -1123,48 +1112,44 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = self.unet.config.block_out_channels[block_id]
-
             if cross_attention_dim is None:
                 attn_procs[name] = AttnProcessor2_0()
             else:
                 cross_attention_dim = cross[layer_num % 8]
                 layer_num += 1
                 if cross_attention_dim == 768:
-                    attn_procs[name] = IPAttnProcessor2_0(
+                    attn_procs[name].scale = IPAttnProcessor2_0(
                         hidden_size=hidden_size,
                         name=name,
                         cross_attention_dim=cross_attention_dim,
-
+                        text_scale=100,
+                        scale=ap_scale,
                         num_tokens=8,
                         do_copy=False
-                    ).to(
+                    ).to(device, dtype=dtype)
                 else:
                     attn_procs[name] = AttnProcessor2_0()
-
-        state_dict = torch.load(adapter_weight, map_location=device)
+        state_dict = torch.load(ap_adapter_path, map_location=device)
         for name, processor in attn_procs.items():
             if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
                 weight_name_v = name + ".to_v_ip.weight"
                 weight_name_k = name + ".to_k_ip.weight"
-
-
+                if dtype == torch.float32:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].float())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].float())
+                elif dtype == torch.float16:
+                    processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
+                    processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
         self.unet.set_attn_processor(attn_procs)
-        self.
-        self.unet = self.unet.to(DEVICE, dtype=torch.float32)
-        self.language_model = self.language_model.to(DEVICE, dtype=torch.float32)
-        self.projection_model = self.projection_model.to(DEVICE, dtype=torch.float32)
-        self.vocoder = self.vocoder.to(DEVICE, dtype=torch.float32)
-        self.text_encoder = self.text_encoder.to(DEVICE, dtype=torch.float32)
-        self.text_encoder_2 = self.text_encoder_2.to(DEVICE, dtype=torch.float32)
+        self.pipeline_trained = self.init_trained_pipeline(ap_adapter_path, device, dtype, ap_scale, text_ap_scale)
 
-
-
         # 1. Pre-check
         height, original_waveform_length = self.pre_check(audio_length_in_s, prompt_1, callback_steps, negative_prompt_1)
         _, _ = self.pre_check(audio_length_in_s, prompt_2, callback_steps, negative_prompt_2)
         # print(f"height: {height}, original_waveform_length: {original_waveform_length}") # height: 1000, original_waveform_length: 160000
 
         # # 2. Define call parameters
+        device = "cuda" if torch.cuda.is_available() else "cpu"
         do_classifier_free_guidance = guidance_scale > 1.0
         self.use_lora = use_lora
         self.use_adain = use_adain

@@ -1178,7 +1163,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
                 load_lora_path_1 = save_lora_dir + "/" + weight_name
                 if not os.path.exists(load_lora_path_1):
-                    train_lora(audio_file ,
+                    train_lora(audio_file, dtype, time_pooling ,freq_pooling ,prompt_1, negative_prompt_1, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
                                self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
                                self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
             print(f"Load from {load_lora_path_1}.")

@@ -1193,7 +1178,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
                 load_lora_path_2 = save_lora_dir + "/" + weight_name
                 if not os.path.exists(load_lora_path_2):
-                    train_lora(audio_file2 ,
+                    train_lora(audio_file2, dtype,time_pooling ,freq_pooling ,prompt_2, negative_prompt_2, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
                                self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
                                self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
             print(f"Load from {load_lora_path_2}.")

@@ -1212,75 +1197,29 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
 
 
         # 4. Prepare latent variables
-        # For the first audio file
+        # ------- For the first audio file -------
         original_processor = list(self.unet.attn_processors.values())[0]
-
         if noisy_latent_with_lora:
             self.unet = load_lora(self.unet, lora_1, lora_2, 0)
-        # print(self.unet.attn_processors)
         # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
         audio_latent = self.aud2latent(audio_file, audio_length_in_s).to(device)
-        # mel_spectrogram = self.vae.decode(audio_latent).sample
-        # first_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
-        # first_audio = first_audio[:, :original_waveform_length]
-        # torchaudio.save(f"{self.output_path}/{0:02d}_gt.wav", first_audio, 16000)
-
         # aud_noise_1 is the noisy latent representation of the audio file 1
-        aud_noise_1 = self.ddim_inversion(audio_latent, prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, guidance_scale, num_inference_steps)
-        # We use the pre-trained model to generate the audio file from the noisy latent representation
-        # waveform = pipeline_trained(
-        #     audio_file = audio_file,
-        #     time_pooling= 2,
-        #     freq_pooling= 2,
-        #     prompt= prompt_1,
-        #     latents = aud_noise_1,
-        #     negative_prompt= negative_prompt_1,
-        #     num_inference_steps= 100,
-        #     guidance_scale= guidance_scale,
-        #     num_waveforms_per_prompt= 1,
-        #     audio_length_in_s=10,
-        # ).audios
-        # file_path = os.path.join(self.output_path, f"{0:02d}_gt2.wav")
-        # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
-
+        aud_noise_1 = self.ddim_inversion(audio_latent, prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, guidance_scale, num_inference_steps = num_inference_steps)
         # After reconstructed the audio file 1, we set the original processor back
         if noisy_latent_with_lora:
             self.unet.set_attn_processor(original_processor)
-        # print(self.unet.attn_processors)
 
-        # For the second audio file
+        # ------- For the second audio file -------
         if noisy_latent_with_lora:
             self.unet = load_lora(self.unet, lora_1, lora_2, 1)
-        # print(self.unet.attn_processors)
         # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
         audio_latent = self.aud2latent(audio_file2, audio_length_in_s)
-        # mel_spectrogram = self.vae.decode(audio_latent).sample
-        # last_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
-        # last_audio = last_audio[:, :original_waveform_length]
-        # torchaudio.save(f"{self.output_path}/{num_frames-1:02d}_gt.wav", last_audio, 16000)
         # aud_noise_2 is the noisy latent representation of the audio file 2
-        aud_noise_2 = self.ddim_inversion(audio_latent, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2, guidance_scale, num_inference_steps)
-        # waveform = pipeline_trained(
-        #     audio_file = audio_file2,
-        #     time_pooling= 2,
-        #     freq_pooling= 2,
-        #     prompt= prompt_2,
-        #     latents = aud_noise_2,
-        #     negative_prompt= negative_prompt_2,
-        #     num_inference_steps= 100,
-        #     guidance_scale= guidance_scale,
-        #     num_waveforms_per_prompt= 1,
-        #     audio_length_in_s=10,
-        # ).audios
-        # file_path = os.path.join(self.output_path, f"{num_frames-1:02d}_gt2.wav")
-        # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
+        aud_noise_2 = self.ddim_inversion(audio_latent, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2, guidance_scale, num_inference_steps = num_inference_steps)
         if noisy_latent_with_lora:
             self.unet.set_attn_processor(original_processor)
-        # print(self.unet.attn_processors)
         # After reconstructed the audio file 1, we set the original processor back
         original_processor = list(self.unet.attn_processors.values())[0]
-
-
         def morph(alpha_list, desc):
             audios = []
             # if attn_beta is not None:

@@ -1288,11 +1227,9 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 self.unet = load_lora(
                     self.unet, lora_1, lora_2, 0 if fix_lora is None else fix_lora)
             attn_processor_dict = {}
-            # print(self.unet.attn_processors)
             for k in self.unet.attn_processors.keys():
                 # print(k)
                 if do_replace_attn(k):
-                    # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
                     if self.use_lora:
                         attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
                                                                 self.aud1_dict, k)

@@ -1300,16 +1237,8 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                         attn_processor_dict[k] = StoreProcessor(original_processor,
                                                                 self.aud1_dict, k)
                 else:
-                    attn_processor_dict[k] = self.unet.attn_processors[k]
-
-
-            # print(attn_processor_dict)
-
-            # print(self.unet.attn_processors)
-            # self.unet.set_attn_processor(attn_processor_dict)
-            # print(self.unet.attn_processors)
-
-            first_audio = self.cal_latent(
+                    attn_processor_dict[k] = self.unet.attn_processors[k]
+            first_audio, first_latents = self.cal_latent(
                 audio_length_in_s,
                 time_pooling,
                 freq_pooling,

@@ -1335,14 +1264,12 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             self.unet.set_attn_processor(original_processor)
             file_path = os.path.join(self.output_path, f"{0:02d}.wav")
             scipy.io.wavfile.write(file_path, rate=16000, data=first_audio)
-
             if self.use_lora:
                 self.unet = load_lora(
                     self.unet, lora_1, lora_2, 1 if fix_lora is None else fix_lora)
             attn_processor_dict = {}
             for k in self.unet.attn_processors.keys():
                 if do_replace_attn(k):
-                    # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
                     if self.use_lora:
                         attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
                                                                 self.aud2_dict, k)

@@ -1351,8 +1278,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                                                                 self.aud2_dict, k)
                 else:
                     attn_processor_dict[k] = self.unet.attn_processors[k]
-
-            last_audio = self.cal_latent(
+            last_audio, last_latents = self.cal_latent(
                 audio_length_in_s,
                 time_pooling,
                 freq_pooling,

@@ -1376,6 +1302,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             )
             file_path = os.path.join(self.output_path, f"{num_frames-1:02d}.wav")
             scipy.io.wavfile.write(file_path, rate=16000, data=last_audio)
+
             self.unet.set_attn_processor(original_processor)
 
             for i in tqdm(range(1, num_frames - 1), desc=desc):

@@ -1395,8 +1322,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                                               original_processor, k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
                 else:
                     attn_processor_dict[k] = self.unet.attn_processors[k]
-
-                audio = self.cal_latent(
+                audio, latents = self.cal_latent(
                     audio_length_in_s,
                     time_pooling,
                     freq_pooling,
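Several hunks above repeat the same dtype-dependent copy of the adapter's to_k_ip/to_v_ip projections (.float() for torch.float32, .half() for torch.float16), both in init_trained_pipeline and in __call__. The sketch below expresses that recurring pattern as a single helper, using Tensor.to(dtype) instead of the explicit branches; it is illustrative only, not part of the commit, and it assumes each IP-attention processor exposes to_k_ip/to_v_ip linear layers as in the hunks above.

import torch

def load_ip_adapter_weights(attn_procs, state_dict, dtype):
    # Copy the AP-adapter key/value projections into every IP-attention
    # processor, cast to the pipeline's working dtype.
    for name, processor in attn_procs.items():
        if hasattr(processor, "to_v_ip") and hasattr(processor, "to_k_ip"):
            v = state_dict[name + ".to_v_ip.weight"].to(dtype)
            k = state_dict[name + ".to_k_ip.weight"].to(dtype)
            processor.to_v_ip.weight = torch.nn.Parameter(v)
            processor.to_k_ip.weight = torch.nn.Parameter(k)
    return attn_procs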
utils/lora_utils_successed_ver1.py
CHANGED

@@ -449,7 +449,7 @@ def plot_loss(loss_history, loss_plot_path, lora_steps):
 # lora_steps: number of lora training step
 # lora_lr: learning rate of lora training
 # lora_rank: the rank of lora
-def train_lora(audio_path ,
+def train_lora(audio_path ,dtype ,time_pooling ,freq_pooling ,prompt, negative_prompt, guidance_scale, save_lora_dir, tokenizer=None, tokenizer_2=None,
                text_encoder=None, text_encoder_2=None, GPT2=None, projection_model=None, vocoder=None,
                vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
     time_pooling = time_pooling

@@ -534,7 +534,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
                 scale=1.0,
                 num_tokens=8,
                 do_copy = do_copy
-            ).to(device, dtype=
+            ).to(device, dtype=dtype)
         else:
             unet_lora_attn_procs[name] = AttnProcessor2_0()
     unet.set_attn_processor(unet_lora_attn_procs)

@@ -580,7 +580,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
     fbank = torch.zeros((1024, 128))
     ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
     mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
-    model = AudioMAEConditionCTPoolRand().to(device).to(dtype=
+    model = AudioMAEConditionCTPoolRand().to(device).to(dtype=dtype)
     model.eval()
     mel_spect_tensor = mel_spect_tensor.to(device, dtype=next(model.parameters()).dtype)
     LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)

@@ -599,24 +599,6 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
     generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
     model_dtype = next(unet.parameters()).dtype
     generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
-
-    # num_channels_latents = unet.config.in_channels
-    # batch_size = 1
-    # num_waveforms_per_prompt = 1
-    # generator = None
-    # latents = None
-    # latents = prepare_latents(
-    #     vae,
-    #     vocoder,
-    #     noise_scheduler,
-    #     batch_size * num_waveforms_per_prompt,
-    #     num_channels_latents,
-    #     height,
-    #     prompt_embeds.dtype,
-    #     device,
-    #     generator,
-    #     latents,
-    # )
 
     loss_history = []
     if not os.path.exists(save_lora_dir):

@@ -683,7 +665,7 @@ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_
         safe_serialization=safe_serialization
     )
 
-def load_lora(unet, lora_0, lora_1, alpha):
+def load_lora(unet, lora_0, lora_1, alpha, dtype):
     attn_procs = unet.attn_processors
     for name, processor in attn_procs.items():
         if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):

@@ -691,10 +673,16 @@ def load_lora(unet, lora_0, lora_1, alpha):
             weight_name_k = name + ".to_k_ip.weight"
             if weight_name_v in lora_0 and weight_name_v in lora_1:
                 v_weight = (1 - alpha) * lora_0[weight_name_v] + alpha * lora_1[weight_name_v]
-
+                if dtype == torch.float32:
+                    processor.to_v_ip.weight = torch.nn.Parameter(v_weight.float())
+                elif dtype == torch.float16:
+                    processor.to_v_ip.weight = torch.nn.Parameter(v_weight.half())
 
             if weight_name_k in lora_0 and weight_name_k in lora_1:
                 k_weight = (1 - alpha) * lora_0[weight_name_k] + alpha * lora_1[weight_name_k]
-
+                if dtype == torch.float32:
+                    processor.to_k_ip.weight = torch.nn.Parameter(k_weight.float())
+                elif dtype == torch.float16:
+                    processor.to_k_ip.weight = torch.nn.Parameter(k_weight.half())
     unet.set_attn_processor(attn_procs)
     return unet
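load_lora now also receives the working dtype and blends the two per-audio adapters by linear interpolation, (1 - alpha) * lora_0 + alpha * lora_1, before writing the result back into the attention processors. A standalone sketch of that blending step on plain state dicts (illustrative, not part of the commit; the checkpoint paths in the usage comment are hypothetical, and key names follow the to_*_ip.weight convention used above):

import torch

def blend_adapters(lora_0, lora_1, alpha, dtype=torch.float32):
    # Interpolate every tensor the two adapters share: alpha=0 keeps
    # adapter 0, alpha=1 keeps adapter 1, values in between morph.
    return {
        key: ((1 - alpha) * lora_0[key] + alpha * lora_1[key]).to(dtype)
        for key in lora_0.keys() & lora_1.keys()
    }

# Usage (hypothetical paths):
# lora_0 = torch.load("lora/clip_a_lora_0.ckpt", map_location="cpu")
# lora_1 = torch.load("lora/clip_b_lora_1.ckpt", map_location="cpu")
# halfway = blend_adapters(lora_0, lora_1, alpha=0.5)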