# gryannote / app.py
# ahmad walidurosyad
# Switch to Gradio SDK with DiariZen git install in requirements.txt
# bfe81d6
import gradio as gr
import spaces
import os
import torch
import tempfile
from pathlib import Path
# DiariZen is treated as an optional import: when it is missing the module
# still loads, records that fact in DIARIZEN_AVAILABLE and prints a hint.
try:
    from diarizen.pipelines.inference import DiariZenPipeline
except ImportError:
    DIARIZEN_AVAILABLE = False
    print("⚠️ DiariZen not available - install from https://github.com/BUTSpeechFIT/DiariZen")
else:
    DIARIZEN_AVAILABLE = True

# Cache of already-loaded pipelines, keyed by model ID.
pipeline_cache = dict()
def load_diarizen_pipeline(model_id="BUT-FIT/diarizen-wavlm-large-s80-md"):
    """Return the DiariZen pipeline for ``model_id``, loading it on first use.

    Loaded pipelines are stored in the module-level ``pipeline_cache`` so that
    later calls with the same ID get the same object back. A freshly loaded
    pipeline is moved to CUDA when ``torch.cuda.is_available()`` is true.

    Any exception raised while loading is printed and then re-raised to the
    caller; nothing is cached in that case.
    """
    # Fast path: this model was loaded by an earlier call.
    try:
        return pipeline_cache[model_id]
    except KeyError:
        pass

    try:
        print(f"Loading DiariZen model: {model_id}")
        loaded = DiariZenPipeline.from_pretrained(model_id)

        # Move to GPU if available
        on_gpu = torch.cuda.is_available()
        if on_gpu:
            print("Moving pipeline to CUDA")
            loaded.to(torch.device("cuda"))

        # Only cache once the pipeline is fully set up.
        pipeline_cache[model_id] = loaded
        print("✅ Model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e
    return loaded
def format_diarization_results(annotations):
    """Render speaker turns as a Markdown-style table.

    ``annotations`` must offer ``itertracks(yield_label=True)`` yielding
    ``(turn, track, speaker)`` triples, where ``turn`` exposes numeric
    ``start`` and ``end`` attributes. The track element is ignored.

    Returns a single string: a title, the table header, and one row per
    turn with start, end and duration formatted to two decimals.
    """
    header = (
        "# Diarization Results\n\n"
        "| Start Time | End Time | Duration | Speaker |\n"
        "|------------|----------|----------|----------|\n"
    )
    rows = [
        f"| {segment.start:8.2f}s | {segment.end:8.2f}s | {segment.end - segment.start:6.2f}s | {label} |\n"
        for segment, _track, label in annotations.itertracks(yield_label=True)
    ]
    return header + "".join(rows)
def _rttm_token(value, fallback):
    """Return ``value`` as a single whitespace-free RTTM field."""
    # RTTM records are split on whitespace, so an embedded space would shift
    # every following column; collapse each whitespace run to one underscore.
    return "_".join(str(value).split()) or fallback


def save_rttm(annotations, audio_filename):
    """Write the speaker turns to an RTTM file and return its path.

    Args:
        annotations: object offering ``itertracks(yield_label=True)`` that
            yields ``(turn, track, speaker)`` triples, where ``turn`` has
            numeric ``start`` and ``end`` attributes.
        audio_filename: name used for the output file (``<name>.rttm``) and,
            with whitespace replaced by underscores, as the RTTM file ID.

    Returns:
        The path of the written file as a string. The file lives in a new
        temporary directory created for this call.
    """
    file_id = _rttm_token(audio_filename, "unknown")

    # mkdtemp gives every call its own directory, so two calls with the same
    # audio name never overwrite each other's output.
    # NOTE(review): the directory is never removed here; the caller hands the
    # path to the UI for download, so cleanup would have to happen elsewhere.
    rttm_path = Path(tempfile.mkdtemp()) / f"{audio_filename}.rttm"

    # RTTM format: SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker> <NA> <NA>
    records = [
        f"SPEAKER {file_id} 1 {turn.start:.3f} {turn.end - turn.start:.3f} "
        f"<NA> <NA> {_rttm_token(speaker, 'unknown')} <NA> <NA>\n"
        for turn, _, speaker in annotations.itertracks(yield_label=True)
    ]

    # Explicit UTF-8 so non-ASCII file IDs or speaker labels do not depend on
    # the platform's default locale encoding.
    with open(rttm_path, "w", encoding="utf-8") as f:
        f.writelines(records)

    return str(rttm_path)
@spaces.GPU(duration=120)
def diarize_audio(audio_file, model_choice):
    """Run speaker diarization on one audio file for the Gradio UI.

    Args:
        audio_file: path of the audio to process, or ``None`` when nothing
            was provided.
        model_choice: one of the display names listed in the lookup table
            below; it is translated to a model ID before loading.

    Returns:
        A ``(text, rttm_path)`` pair. On success ``text`` is the formatted
        result table and ``rttm_path`` the path of the written RTTM file.
        On any failure ``text`` is a human-readable message and the second
        element is ``None``; exceptions are reported, not propagated.
    """
    # Guard clauses: nothing to do without the library or without input.
    if not DIARIZEN_AVAILABLE:
        return "❌ Error: DiariZen not installed. Please install from https://github.com/BUTSpeechFIT/DiariZen", None
    if audio_file is None:
        return "⚠️ Please upload an audio file", None

    # Display name shown in the UI -> model ID passed to the loader.
    model_ids = {
        "WavLM Large (Recommended)": "BUT-FIT/diarizen-wavlm-large-s80-md",
        "WavLM Base (Faster)": "BUT-FIT/diarizen-wavlm-base-s80-md",
        "WavLM Large MLC": "BUT-FIT/diarizen-wavlm-large-s80-mlc",
    }

    try:
        # An unknown display name raises KeyError here and is reported like
        # any other failure.
        pipeline = load_diarizen_pipeline(model_ids[model_choice])

        # The file stem names both the RTTM file and its records.
        audio_name = Path(audio_file).stem
        print(f"🎤 Processing audio: {audio_file}")

        annotations = pipeline(audio_file)
        print("✅ Diarization complete")

        # Tuple elements are evaluated left to right: format first, then save.
        return (
            format_diarization_results(annotations),
            save_rttm(annotations, audio_name),
        )
    except Exception as e:
        import traceback

        error_msg = f"❌ Error during diarization:\n{e!s}"
        print(error_msg)
        # Full traceback goes to the process output for debugging.
        traceback.print_exc()
        return error_msg, None
# Build Gradio Interface
# Components are created inside nested context managers; the order in which
# they are created is the order in which they are laid out on the page.
# The Markdown literals are kept flush-left so their text stays unchanged.
with gr.Blocks(title="DiariZen Speaker Diarization") as demo:
    # Page heading and one-line usage summary.
    gr.Markdown("""
# 🎙️ DiariZen - Speaker Diarization
**Upload audio → Select model → Run diarization → View results & Download RTTM**
DiariZen: High-performance speaker diarization toolkit from BUT-FIT
""")
    # Extra banner shown only when the DiariZen import failed at start-up.
    if not DIARIZEN_AVAILABLE:
        gr.Markdown("""
⚠️ **DiariZen not installed**
To use this Space, DiariZen must be installed. Please see:
https://github.com/BUTSpeechFIT/DiariZen
""")
    with gr.Row():
        # Left column: inputs and the run button.
        with gr.Column():
            # Audio input
            # type="filepath" hands the callback a path string, not raw samples.
            audio_input = gr.Audio(
                label="📤 Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"]
            )
            # Model selection
            # These display names must match the keys that diarize_audio maps
            # to model IDs.
            model_dropdown = gr.Dropdown(
                choices=[
                    "WavLM Large (Recommended)",
                    "WavLM Base (Faster)",
                    "WavLM Large MLC"
                ],
                value="WavLM Large (Recommended)",
                label="🤖 Select Model",
                info="Choose diarization model"
            )
            # Run button
            run_btn = gr.Button("▶️ Run Diarization", variant="primary", size="lg")
        # Right column: outputs.
        with gr.Column():
            # Results output
            results_output = gr.Textbox(
                label="📊 Diarization Results",
                lines=20,
                max_lines=30,
                show_copy_button=True
            )
            # RTTM download
            # interactive=False: this component only offers the generated file.
            rttm_output = gr.File(
                label="📝 Download RTTM",
                interactive=False
            )
    # Model information
    # Collapsed by default (open=False).
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown("""
### Available Models
| Model | Parameters | Speed | Quality | Description |
|-------|-----------|-------|---------|-------------|
| WavLM Large | 63M | Fast | High | Recommended for most use cases |
| WavLM Base | - | Very Fast | Good | Faster variant for quick processing |
| WavLM Large MLC | 63M | Fast | High | Multi-language optimized |
### Performance
DiariZen substantially outperforms Pyannote v3.1:
- AMI-SDM: 13.9% DER (vs 22.4% Pyannote)
- VoxConverse: 9.1% DER (vs 11.3% Pyannote)
- AISHELL-4: 10.1% DER (vs 12.2% Pyannote)
### Citation
```bibtex
@inproceedings{diariZen2024,
title={DiariZen: A toolkit for speaker diarization},
author={Han, Ivo and Landini, Federico and Burget, Lukáš and Černocký, Jan},
booktitle={INTERSPEECH},
year={2024}
}
```
""")
    # Footer
    gr.Markdown("""
---
**Source**: [github.com/BUTSpeechFIT/DiariZen](https://github.com/BUTSpeechFIT/DiariZen)
**License**: MIT (Code) | Research/Non-commercial (Models)
""")
    # Connect button to function
    # The callback's two return values fill the outputs in the order listed:
    # result text first, RTTM file path second.
    run_btn.click(
        fn=diarize_audio,
        inputs=[audio_input, model_dropdown],
        outputs=[results_output, rttm_output]
    )
# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()