File size: 6,987 Bytes
593d539
bfe81d6
593d539
 
5c245fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03a5c9f
bfe81d6
5c245fa
bfe81d6
5c245fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593d539
5c245fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593d539
 
 
5c245fa
 
 
 
 
 
593d539
5c245fa
 
 
 
 
 
 
 
 
 
 
593d539
5c245fa
 
593d539
5c245fa
 
 
 
 
 
 
 
b02904a
5c245fa
 
 
 
 
8919734
5c245fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02904a
5c245fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8919734
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import gradio as gr
import spaces
import os
import torch
import tempfile
from pathlib import Path

# DiariZen is treated as optional at import time: a failed import only flips
# the availability flag and prints an install hint instead of aborting.
try:
    from diarizen.pipelines.inference import DiariZenPipeline
except ImportError:
    DIARIZEN_AVAILABLE = False
    print("⚠️ DiariZen not available - install from https://github.com/BUTSpeechFIT/DiariZen")
else:
    DIARIZEN_AVAILABLE = True

# Already-loaded pipelines, keyed by model id.
pipeline_cache = dict()

def load_diarizen_pipeline(model_id="BUT-FIT/diarizen-wavlm-large-s80-md"):
    """Return the DiariZen pipeline for ``model_id``, loading it on first use.

    Pipelines are stored in the module-level ``pipeline_cache`` so that a
    later call with the same id returns the already-loaded object. A newly
    loaded pipeline is moved to CUDA when ``torch.cuda.is_available()``
    reports a device.

    Args:
        model_id: Identifier passed to ``DiariZenPipeline.from_pretrained``.

    Returns:
        The cached or newly loaded pipeline object.

    Raises:
        Exception: Whatever loading or moving the pipeline raised; the error
            is printed first and then re-raised.
    """
    try:
        return pipeline_cache[model_id]
    except KeyError:
        pass  # Not loaded yet; fall through to the loading path.

    try:
        print(f"Loading DiariZen model: {model_id}")
        loaded = DiariZenPipeline.from_pretrained(model_id)

        # Use the GPU whenever one is visible to torch.
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            print("Moving pipeline to CUDA")
            loaded.to(torch.device("cuda"))

        # Only a fully prepared pipeline is remembered.
        pipeline_cache[model_id] = loaded
        print(f"✅ Model loaded successfully")
    except Exception as err:
        print(f"❌ Error loading model: {err}")
        raise err

    return loaded

def format_diarization_results(annotations):
    """Render diarization tracks as a Markdown-style table.

    Args:
        annotations: Object whose ``itertracks(yield_label=True)`` yields
            ``(turn, track, speaker)`` triples, where ``turn`` has numeric
            ``start`` and ``end`` attributes.

    Returns:
        A single string: a heading, the table header, and one row per track
        with start, end and duration formatted to two decimals.
    """
    lines = [
        "# Diarization Results\n\n",
        "| Start Time | End Time | Duration | Speaker |\n",
        "|------------|----------|----------|----------|\n",
    ]

    # One table row per track; the track id itself is not shown.
    lines.extend(
        f"| {segment.start:8.2f}s | {segment.end:8.2f}s | {segment.end - segment.start:6.2f}s | {label} |\n"
        for segment, _track, label in annotations.itertracks(yield_label=True)
    )

    return "".join(lines)

def save_rttm(annotations, audio_filename):
    """Write the diarization to an RTTM file and return the file's path.

    Every track becomes one record of the form
    ``SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker> <NA> <NA>``
    with start and duration formatted to three decimals.

    RTTM fields are separated by whitespace, so any whitespace inside the
    file id or a speaker label is replaced by ``_`` in the records (an empty
    value becomes ``<NA>``). Otherwise a name such as ``"my recording"``
    would produce records with the wrong number of fields. The output file
    itself is still named ``<audio_filename>.rttm`` and is written as UTF-8
    regardless of the platform's default encoding.

    Args:
        annotations: Object whose ``itertracks(yield_label=True)`` yields
            ``(turn, track, speaker)`` triples, where ``turn`` has numeric
            ``start`` and ``end`` attributes.
        audio_filename: Base name used for the output file and, sanitised,
            as the file id in each record.

    Returns:
        The path of the written RTTM file, as a string.
    """

    def _rttm_token(value):
        # Collapse each run of whitespace to "_" so the value stays one field.
        return "_".join(str(value).split()) or "<NA>"

    # A fresh directory per call. It is not removed here because the returned
    # path must stay readable after this function returns.
    out_dir = Path(tempfile.mkdtemp())
    rttm_path = out_dir / f"{audio_filename}.rttm"
    file_id = _rttm_token(audio_filename)

    with open(rttm_path, "w", encoding="utf-8") as f:
        f.writelines(
            f"SPEAKER {file_id} 1 {turn.start:.3f} {turn.end - turn.start:.3f} "
            f"<NA> <NA> {_rttm_token(speaker)} <NA> <NA>\n"
            for turn, _, speaker in annotations.itertracks(yield_label=True)
        )

    return str(rttm_path)

@spaces.GPU(duration=120)
def diarize_audio(audio_file, model_choice):
    """Run speaker diarization on one audio file and package the results.

    Args:
        audio_file: Path of the audio to process, or ``None`` when no audio
            was provided.
        model_choice: Display name of the model; it is looked up in the
            name-to-model-id table below. An unknown name ends up in the
            generic error path.

    Returns:
        A ``(text, rttm_path)`` pair. On success ``text`` is the formatted
        result table and ``rttm_path`` is the path of the written RTTM file.
        When DiariZen is missing, no audio was given, or anything raises,
        ``text`` is a warning/error message and ``rttm_path`` is ``None``.
    """
    if not DIARIZEN_AVAILABLE:
        return "❌ Error: DiariZen not installed. Please install from https://github.com/BUTSpeechFIT/DiariZen", None

    if audio_file is None:
        return "⚠️ Please upload an audio file", None

    try:
        # Display name shown in the UI -> model id used for loading.
        model_ids = {
            "WavLM Large (Recommended)": "BUT-FIT/diarizen-wavlm-large-s80-md",
            "WavLM Base (Faster)": "BUT-FIT/diarizen-wavlm-base-s80-md",
            "WavLM Large MLC": "BUT-FIT/diarizen-wavlm-large-s80-mlc"
        }
        diarizer = load_diarizen_pipeline(model_ids[model_choice])

        # The RTTM file is named after the audio file, minus its extension.
        audio_name = Path(audio_file).stem

        print(f"🎤 Processing audio: {audio_file}")
        annotations = diarizer(audio_file)
        print(f"✅ Diarization complete")

        # Human-readable table first, then the downloadable RTTM file.
        return format_diarization_results(annotations), save_rttm(annotations, audio_name)

    except Exception as exc:
        # Report the failure as the text result instead of raising, and dump
        # the full traceback to stdout.
        error_msg = f"❌ Error during diarization:\n{str(exc)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg, None

# Build the Gradio interface. Components appear in the page in the order they
# are created inside the nested Blocks/Row/Column contexts, so statement order
# here is the layout.
with gr.Blocks(title="DiariZen Speaker Diarization") as demo:
    # Page header
    gr.Markdown("""
    # 🎙️ DiariZen - Speaker Diarization

    **Upload audio → Select model → Run diarization → View results & Download RTTM**

    DiariZen: High-performance speaker diarization toolkit from BUT-FIT
    """)

    # Shown only when the DiariZen import at the top of the file failed.
    if not DIARIZEN_AVAILABLE:
        gr.Markdown("""
        ⚠️ **DiariZen not installed**

        To use this Space, DiariZen must be installed. Please see:
        https://github.com/BUTSpeechFIT/DiariZen
        """)

    with gr.Row():
        # Left column: inputs
        with gr.Column():
            # Audio input; its value is passed to diarize_audio as audio_file.
            audio_input = gr.Audio(
                label="📤 Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"]
            )

            # Model selection; these labels must match the keys of the
            # name-to-model-id table inside diarize_audio.
            model_dropdown = gr.Dropdown(
                choices=[
                    "WavLM Large (Recommended)",
                    "WavLM Base (Faster)",
                    "WavLM Large MLC"
                ],
                value="WavLM Large (Recommended)",
                label="🤖 Select Model",
                info="Choose diarization model"
            )

            # Run button; wired to diarize_audio at the end of this block.
            run_btn = gr.Button("▶️ Run Diarization", variant="primary", size="lg")

        # Right column: outputs
        with gr.Column():
            # Receives the first value returned by diarize_audio (the result
            # table, or a warning/error message).
            results_output = gr.Textbox(
                label="📊 Diarization Results",
                lines=20,
                max_lines=30,
                show_copy_button=True
            )

            # Receives the second value returned by diarize_audio (the RTTM
            # file path, or None on failure).
            rttm_output = gr.File(
                label="📝 Download RTTM",
                interactive=False
            )

    # Static model information, collapsed by default.
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown("""
        ### Available Models

        | Model | Parameters | Speed | Quality | Description |
        |-------|-----------|-------|---------|-------------|
        | WavLM Large | 63M | Fast | High | Recommended for most use cases |
        | WavLM Base | - | Very Fast | Good | Faster variant for quick processing |
        | WavLM Large MLC | 63M | Fast | High | Multi-language optimized |

        ### Performance

        DiariZen substantially outperforms Pyannote v3.1:
        - AMI-SDM: 13.9% DER (vs 22.4% Pyannote)
        - VoxConverse: 9.1% DER (vs 11.3% Pyannote)
        - AISHELL-4: 10.1% DER (vs 12.2% Pyannote)

        ### Citation

        ```bibtex
        @inproceedings{diariZen2024,
          title={DiariZen: A toolkit for speaker diarization},
          author={Han, Ivo and Landini, Federico and Burget, Lukáš and Černocký, Jan},
          booktitle={INTERSPEECH},
          year={2024}
        }
        ```
        """)

    # Footer with source and license links.
    gr.Markdown("""
    ---
    **Source**: [github.com/BUTSpeechFIT/DiariZen](https://github.com/BUTSpeechFIT/DiariZen)

    **License**: MIT (Code) | Research/Non-commercial (Models)
    """)

    # Clicking the button calls diarize_audio(audio_input, model_dropdown) and
    # routes its two return values to the results box and the RTTM download.
    run_btn.click(
        fn=diarize_audio,
        inputs=[audio_input, model_dropdown],
        outputs=[results_output, rttm_output]
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()