nambn0321 committed on
Commit
3146cac
·
verified ·
1 Parent(s): 8131e32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -1,11 +1,9 @@
1
  import torch
2
  import gradio as gr
3
  import torchaudio
4
- # from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
5
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech
6
  from transformers.models.speecht5 import SpeechT5HifiGan
7
 
8
-
9
  # Load model and processor
10
  processor = SpeechT5Processor.from_pretrained("nambn0321/T5_british")
11
  model = SpeechT5ForTextToSpeech.from_pretrained(
@@ -20,35 +18,37 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
20
  model = model.to(device)
21
  vocoder = vocoder.to(device)
22
 
23
-
24
  def tts_generate(text):
25
- print(f" Input text: {text}")
26
  try:
27
  # Preprocess input
28
- print(" Processing input...")
29
  inputs = processor(text=text, return_tensors="pt").to(device)
30
- print(" Text processed.")
 
 
 
 
 
 
31
 
32
- # Generate waveform directly (with vocoder)
33
- print("🎤 Generating speech waveform...")
34
  with torch.no_grad():
35
- waveform = model.generate_speech(
36
- inputs["input_ids"],
37
- vocoder=vocoder
38
- )
39
- print(" Waveform generated.")
40
 
41
  # Save waveform
42
  output_path = "output.wav"
43
  if waveform.dim() == 1:
44
  waveform = waveform.unsqueeze(0)
45
  torchaudio.save(output_path, waveform.cpu(), sample_rate=16000)
46
- print(f" Audio saved to {output_path}")
47
 
48
  return output_path
49
 
50
  except Exception as e:
51
- print(" Error during TTS generation:", e)
52
  return "Error during speech synthesis."
53
 
54
  # Gradio interface
@@ -61,7 +61,5 @@ demo = gr.Interface(
61
  )
62
 
63
  if __name__ == "__main__":
64
- print(" Launching Gradio demo...")
65
  demo.launch()
66
-
67
-
 
1
  import torch
2
  import gradio as gr
3
  import torchaudio
 
4
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech
5
  from transformers.models.speecht5 import SpeechT5HifiGan
6
 
 
7
  # Load model and processor
8
  processor = SpeechT5Processor.from_pretrained("nambn0321/T5_british")
9
  model = SpeechT5ForTextToSpeech.from_pretrained(
 
18
  model = model.to(device)
19
  vocoder = vocoder.to(device)
20
 
 
21
  def tts_generate(text):
22
+ print(f"Input text: {text}")
23
  try:
24
  # Preprocess input
25
+ print("Processing input...")
26
  inputs = processor(text=text, return_tensors="pt").to(device)
27
+ print("Text processed.")
28
+
29
+ # Generate mel spectrogram with the TTS model (instead of using .generate_speech directly)
30
+ print("🎤 Generating mel spectrogram...")
31
+ with torch.no_grad():
32
+ mel_output, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
33
+ print("Mel spectrogram generated.")
34
 
35
+ # Vocoder to generate waveform from mel spectrogram
36
+ print("🎤 Vocoding to waveform...")
37
  with torch.no_grad():
38
+ waveform = vocoder.decode(mel_output)
39
+ print("Waveform generated.")
 
 
 
40
 
41
  # Save waveform
42
  output_path = "output.wav"
43
  if waveform.dim() == 1:
44
  waveform = waveform.unsqueeze(0)
45
  torchaudio.save(output_path, waveform.cpu(), sample_rate=16000)
46
+ print(f"Audio saved to {output_path}")
47
 
48
  return output_path
49
 
50
  except Exception as e:
51
+ print("Error during TTS generation:", e)
52
  return "Error during speech synthesis."
53
 
54
  # Gradio interface
 
61
  )
62
 
63
  if __name__ == "__main__":
64
+ print("Launching Gradio demo...")
65
  demo.launch()