nambn0321 commited on
Commit
203c275
·
verified ·
1 Parent(s): 657fb95

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import torchaudio
4
+ # from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
5
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech
6
+ from transformers.models.speecht5 import SpeechT5HifiGan
7
+
8
+
9
+ # Load model and processor
10
+ processor = SpeechT5Processor.from_pretrained("nambn0321/TTS_with_T5_4")
11
+ model = SpeechT5ForTextToSpeech.from_pretrained(
12
+ "nambn0321/TTS_with_T5_4",
13
+ use_safetensors=True,
14
+ trust_remote_code=True
15
+ )
16
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
17
+
18
+ # Move to CUDA if available
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ model = model.to(device)
21
+ vocoder = vocoder.to(device)
22
+
23
+
24
+ def tts_generate(text):
25
+ print(f" Input text: {text}")
26
+ try:
27
+ # Preprocess input
28
+ print(" Processing input...")
29
+ inputs = processor(text=text, return_tensors="pt").to(device)
30
+ print(" Text processed.")
31
+
32
+ # Generate waveform directly (with vocoder)
33
+ print("🎤 Generating speech waveform...")
34
+ with torch.no_grad():
35
+ waveform = model.generate_speech(
36
+ inputs["input_ids"],
37
+ vocoder=vocoder
38
+ )
39
+ print(" Waveform generated.")
40
+
41
+ # Save waveform
42
+ output_path = "output.wav"
43
+ if waveform.dim() == 1:
44
+ waveform = waveform.unsqueeze(0)
45
+ torchaudio.save(output_path, waveform.cpu(), sample_rate=16000)
46
+ print(f" Audio saved to {output_path}")
47
+
48
+ return output_path
49
+
50
+ except Exception as e:
51
+ print(" Error during TTS generation:", e)
52
+ return "Error during speech synthesis."
53
+
54
+ # Gradio interface
55
+ demo = gr.Interface(
56
+ fn=tts_generate,
57
+ inputs=gr.Textbox(label="Enter text"),
58
+ outputs=gr.Audio(label="Generated Speech", type="filepath"),
59
+ title="SpeechT5 Text-to-Speech",
60
+ description="Enter text and hear it spoken with SpeechT5 + HiFi-GAN vocoder."
61
+ )
62
+
63
+ if __name__ == "__main__":
64
+ print(" Launching Gradio demo...")
65
+ demo.launch()
66
+
67
+