|
|
import gradio as gr |
|
|
import spaces |
|
|
import os |
|
|
import torch |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# DiariZen is an optional dependency: the UI still builds without it, and
# DIARIZEN_AVAILABLE records whether the pipeline class could be imported.
try:
    from diarizen.pipelines.inference import DiariZenPipeline
except ImportError:
    DIARIZEN_AVAILABLE = False
    print("⚠️ DiariZen not available - install from https://github.com/BUTSpeechFIT/DiariZen")
else:
    DIARIZEN_AVAILABLE = True

# Loaded pipelines, keyed by model id (filled by load_diarizen_pipeline).
pipeline_cache = {}
|
|
|
|
|
def load_diarizen_pipeline(model_id="BUT-FIT/diarizen-wavlm-large-s80-md"):
    """Load a DiariZen pipeline, reusing a previously loaded one when possible.

    Successfully loaded pipelines are stored in the module-level
    ``pipeline_cache`` dict keyed by ``model_id``, so each model is loaded at
    most once per process. A failed load is not cached, so the next call
    retries it.

    Args:
        model_id: Identifier passed to ``DiariZenPipeline.from_pretrained``.

    Returns:
        The pipeline object. When ``torch.cuda.is_available()`` is true it has
        been moved to the CUDA device via ``pipeline.to(...)``.

    Raises:
        Exception: Anything raised while loading or moving the pipeline is
            printed and then re-raised unchanged (original traceback intact).
    """
    if model_id in pipeline_cache:
        return pipeline_cache[model_id]

    print(f"Loading DiariZen model: {model_id}")
    try:
        pipeline = DiariZenPipeline.from_pretrained(model_id)
        if torch.cuda.is_available():
            print("Moving pipeline to CUDA")
            pipeline.to(torch.device("cuda"))
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        # Bare raise keeps the original exception and traceback untouched.
        raise

    pipeline_cache[model_id] = pipeline
    print("✅ Model loaded successfully")
    return pipeline
|
|
|
|
|
def format_diarization_results(annotations):
    """Render diarization output as a Markdown table.

    Args:
        annotations: Object whose ``itertracks(yield_label=True)`` yields
            ``(segment, track, label)`` triples; ``segment`` must expose
            ``start`` and ``end`` numbers, which are printed with an ``s``
            suffix (i.e. treated as seconds).

    Returns:
        A string with a heading, a table header, and one row per speaker
        turn (start, end, duration, speaker). If there are no turns, a short
        note is appended instead of leaving a bare, empty table.
    """
    header = [
        "# Diarization Results\n\n",
        "| Start Time | End Time | Duration | Speaker |\n",
        "|------------|----------|----------|----------|\n",
    ]
    rows = [
        f"| {turn.start:8.2f}s | {turn.end:8.2f}s | {turn.end - turn.start:6.2f}s | {speaker} |\n"
        for turn, _, speaker in annotations.itertracks(yield_label=True)
    ]
    if not rows:
        # Leading blank line so Markdown does not treat the note as a table row.
        rows.append("\n_No speaker turns detected._\n")
    return "".join(header + rows)
|
|
|
|
|
def _rttm_token(value):
    """Collapse whitespace runs in ``value`` to ``_`` so it stays one RTTM field."""
    return "_".join(str(value).split())


def save_rttm(annotations, audio_filename):
    """Write speaker turns to an RTTM file inside a fresh temporary directory.

    Each turn becomes one line of the form
    ``SPEAKER <file-id> 1 <onset> <duration> <NA> <NA> <speaker> <NA> <NA>``
    with onset and duration written to three decimals.

    RTTM fields are whitespace-delimited, so whitespace inside the file id or
    the speaker label would shift every following field. Such runs are
    replaced with ``_`` inside the records. The output *file name* still uses
    ``audio_filename`` verbatim.

    Args:
        annotations: Object whose ``itertracks(yield_label=True)`` yields
            ``(segment, track, label)`` with ``segment.start``/``segment.end``.
        audio_filename: Base name used for the output file and, after
            sanitizing, for the RTTM file-id field (``"audio"`` if empty).

    Returns:
        Path of the written ``.rttm`` file as a string. The temporary
        directory is not removed here (presumably so the file can still be
        served for download; confirm with the caller).
    """
    file_id = _rttm_token(audio_filename) or "audio"
    rttm_path = Path(tempfile.mkdtemp()) / f"{audio_filename}.rttm"
    records = [
        f"SPEAKER {file_id} 1 {turn.start:.3f} {turn.end - turn.start:.3f} "
        f"<NA> <NA> {_rttm_token(speaker)} <NA> <NA>\n"
        for turn, _, speaker in annotations.itertracks(yield_label=True)
    ]
    rttm_path.write_text("".join(records), encoding="utf-8")
    return str(rttm_path)
|
|
|
|
|
@spaces.GPU(duration=120)
def diarize_audio(audio_file, model_choice):
    """Gradio click handler: diarize one audio file with the chosen model.

    The function is wrapped by ``spaces.GPU(duration=120)`` from the Hugging
    Face ``spaces`` package (the duration is presumably in seconds; confirm
    against the ``spaces`` docs).

    Args:
        audio_file: Path to the uploaded/recorded audio (the UI uses
            ``gr.Audio(type="filepath")``), or ``None``/empty when nothing
            was provided.
        model_choice: Display label selected in the model dropdown.

    Returns:
        ``(text, rttm_path)``. On success, ``text`` is the Markdown results
        table and ``rttm_path`` the written RTTM file. On any failure,
        ``text`` is a user-facing message and ``rttm_path`` is ``None``.
    """
    # Dropdown label -> Hugging Face model id.
    model_map = {
        "WavLM Large (Recommended)": "BUT-FIT/diarizen-wavlm-large-s80-md",
        "WavLM Base (Faster)": "BUT-FIT/diarizen-wavlm-base-s80-md",
        "WavLM Large MLC": "BUT-FIT/diarizen-wavlm-large-s80-mlc",
    }

    if not DIARIZEN_AVAILABLE:
        return "❌ Error: DiariZen not installed. Please install from https://github.com/BUTSpeechFIT/DiariZen", None

    if not audio_file:
        return "⚠️ Please upload an audio file", None

    # A cleared or unexpected dropdown value previously surfaced as a bare
    # KeyError ("'None'"); report it explicitly instead.
    model_id = model_map.get(model_choice)
    if model_id is None:
        choices = ", ".join(model_map)
        return f"⚠️ Unknown model selection: {model_choice!r}. Please choose one of: {choices}", None

    try:
        pipeline = load_diarizen_pipeline(model_id)
        audio_name = Path(audio_file).stem

        print(f"🎤 Processing audio: {audio_file}")
        annotations = pipeline(audio_file)
        print("✅ Diarization complete")

        results_text = format_diarization_results(annotations)
        rttm_path = save_rttm(annotations, audio_name)
    except Exception as e:
        # Top-level UI boundary: log with traceback, show the message in the UI.
        error_msg = f"❌ Error during diarization:\n{e}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg, None

    return results_text, rttm_path
|
|
|
|
|
|
|
|
with gr.Blocks(title="DiariZen Speaker Diarization") as demo:
    # --- Header ---------------------------------------------------------
    gr.Markdown(value="""
# 🎙️ DiariZen - Speaker Diarization

**Upload audio → Select model → Run diarization → View results & Download RTTM**

DiariZen: High-performance speaker diarization toolkit from BUT-FIT
""")

    # Warning banner, rendered only when the optional diarizen import failed.
    if not DIARIZEN_AVAILABLE:
        gr.Markdown(value="""
⚠️ **DiariZen not installed**

To use this Space, DiariZen must be installed. Please see:
https://github.com/BUTSpeechFIT/DiariZen
""")

    with gr.Row():
        # Left column: inputs and the run button.
        with gr.Column():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="📤 Upload Audio File",
            )
            # These labels must match the keys of the model map in diarize_audio.
            model_dropdown = gr.Dropdown(
                label="🤖 Select Model",
                info="Choose diarization model",
                choices=[
                    "WavLM Large (Recommended)",
                    "WavLM Base (Faster)",
                    "WavLM Large MLC",
                ],
                value="WavLM Large (Recommended)",
            )
            run_btn = gr.Button("▶️ Run Diarization", size="lg", variant="primary")

        # Right column: results text and downloadable RTTM file.
        with gr.Column():
            results_output = gr.Textbox(
                show_copy_button=True,
                max_lines=30,
                lines=20,
                label="📊 Diarization Results",
            )
            rttm_output = gr.File(interactive=False, label="📝 Download RTTM")

    # --- Collapsible model / performance / citation info -----------------
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown(value="""
### Available Models

| Model | Parameters | Speed | Quality | Description |
|-------|-----------|-------|---------|-------------|
| WavLM Large | 63M | Fast | High | Recommended for most use cases |
| WavLM Base | - | Very Fast | Good | Faster variant for quick processing |
| WavLM Large MLC | 63M | Fast | High | Multi-language optimized |

### Performance

DiariZen substantially outperforms Pyannote v3.1:
- AMI-SDM: 13.9% DER (vs 22.4% Pyannote)
- VoxConverse: 9.1% DER (vs 11.3% Pyannote)
- AISHELL-4: 10.1% DER (vs 12.2% Pyannote)

### Citation

```bibtex
@inproceedings{diariZen2024,
title={DiariZen: A toolkit for speaker diarization},
author={Han, Ivo and Landini, Federico and Burget, Lukáš and Černocký, Jan},
booktitle={INTERSPEECH},
year={2024}
}
```
""")

    # --- Footer ---------------------------------------------------------
    gr.Markdown(value="""
---
**Source**: [github.com/BUTSpeechFIT/DiariZen](https://github.com/BUTSpeechFIT/DiariZen)

**License**: MIT (Code) | Research/Non-commercial (Models)
""")

    # Wire the button: (audio path, model label) -> (results text, RTTM file).
    run_btn.click(
        fn=diarize_audio,
        outputs=[results_output, rttm_output],
        inputs=[audio_input, model_dropdown],
    )


if __name__ == "__main__":
    demo.launch()
|
|
|