"""Gradio demo app for DiariZen speaker diarization.

Upload an audio file, pick a DiariZen model, run diarization, then view the
speaker turns as a table and download them as an RTTM file.
"""

import os
import tempfile
import traceback
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr
import spaces
import torch

# DiariZen is an optional dependency: the UI still starts without it and
# shows installation instructions instead of crashing at import time.
try:
    from diarizen.pipelines.inference import DiariZenPipeline
    DIARIZEN_AVAILABLE = True
except ImportError:
    DIARIZEN_AVAILABLE = False
    print("⚠️ DiariZen not available - install from https://github.com/BUTSpeechFIT/DiariZen")

DEFAULT_MODEL_ID = "BUT-FIT/diarizen-wavlm-large-s80-md"

# Single source of truth for the dropdown labels and the model IDs they map
# to, so the UI choices and the lookup in diarize_audio cannot drift apart.
MODEL_IDS = {
    "WavLM Large (Recommended)": "BUT-FIT/diarizen-wavlm-large-s80-md",
    "WavLM Base (Faster)": "BUT-FIT/diarizen-wavlm-base-s80-md",
    "WavLM Large MLC": "BUT-FIT/diarizen-wavlm-large-s80-mlc",
}
DEFAULT_MODEL_CHOICE = "WavLM Large (Recommended)"

# Markdown shown in the UI. Kept at module level (column 0) so the rendered
# text does not depend on how the UI code below is indented.
HEADER_MD = """
# 🎙️ DiariZen - Speaker Diarization

**Upload audio → Select model → Run diarization → View results & Download RTTM**

DiariZen: High-performance speaker diarization toolkit from BUT-FIT
"""

NOT_INSTALLED_MD = """
⚠️ **DiariZen not installed**

To use this Space, DiariZen must be installed.
Please see: https://github.com/BUTSpeechFIT/DiariZen
"""

MODEL_INFO_MD = """
### Available Models

| Model | Parameters | Speed | Quality | Description |
|-------|-----------|-------|---------|-------------|
| WavLM Large | 63M | Fast | High | Recommended for most use cases |
| WavLM Base | - | Very Fast | Good | Faster variant for quick processing |
| WavLM Large MLC | 63M | Fast | High | Multi-language optimized |

### Performance

DiariZen substantially outperforms Pyannote v3.1:
- AMI-SDM: 13.9% DER (vs 22.4% Pyannote)
- VoxConverse: 9.1% DER (vs 11.3% Pyannote)
- AISHELL-4: 10.1% DER (vs 12.2% Pyannote)

### Citation

```bibtex
@inproceedings{diariZen2024,
  title={DiariZen: A toolkit for speaker diarization},
  author={Han, Ivo and Landini, Federico and Burget, Lukáš and Černocký, Jan},
  booktitle={INTERSPEECH},
  year={2024}
}
```
"""

FOOTER_MD = """
---

**Source**: [github.com/BUTSpeechFIT/DiariZen](https://github.com/BUTSpeechFIT/DiariZen)

**License**: MIT (Code) | Research/Non-commercial (Models)
"""

# Model cache: model ID -> loaded pipeline, so each model is loaded only once
# per process.
pipeline_cache = {}


def load_diarizen_pipeline(model_id=DEFAULT_MODEL_ID):
    """Load a DiariZen pipeline, reusing a cached instance when available.

    Args:
        model_id: Model identifier passed to
            ``DiariZenPipeline.from_pretrained``.

    Returns:
        The loaded pipeline; it is moved to CUDA when
        ``torch.cuda.is_available()`` is true.

    Raises:
        RuntimeError: If the DiariZen package could not be imported.
        Exception: Any error raised while loading the model is logged and
            re-raised unchanged.
    """
    if model_id in pipeline_cache:
        return pipeline_cache[model_id]

    if not DIARIZEN_AVAILABLE:
        # Fail with a clear message rather than a NameError on the
        # never-imported DiariZenPipeline symbol.
        raise RuntimeError(
            "DiariZen is not installed - install from "
            "https://github.com/BUTSpeechFIT/DiariZen"
        )

    try:
        print(f"Loading DiariZen model: {model_id}")
        pipeline = DiariZenPipeline.from_pretrained(model_id)

        # Move to GPU if available
        if torch.cuda.is_available():
            print("Moving pipeline to CUDA")
            pipeline.to(torch.device("cuda"))

        pipeline_cache[model_id] = pipeline
        print("✅ Model loaded successfully")
        return pipeline
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise  # bare raise keeps the original traceback intact


def format_diarization_results(annotations) -> str:
    """Format diarization results as a Markdown-style table.

    Args:
        annotations: Object exposing ``itertracks(yield_label=True)`` that
            yields ``(turn, track, speaker)`` triples, where ``turn`` has
            ``start`` and ``end`` attributes (presumably a pyannote
            ``Annotation`` - confirm against the DiariZen pipeline output).

    Returns:
        A heading and a table with one row per speaker turn.
    """
    results = [
        "# Diarization Results\n\n",
        "| Start Time | End Time | Duration | Speaker |\n",
        "|------------|----------|----------|----------|\n",
    ]
    for turn, _, speaker in annotations.itertracks(yield_label=True):
        duration = turn.end - turn.start
        results.append(
            f"| {turn.start:8.2f}s | {turn.end:8.2f}s | {duration:6.2f}s | {speaker} |\n"
        )
    return "".join(results)


def save_rttm(annotations, audio_filename) -> str:
    """Write the speaker turns to an RTTM file in a fresh temporary directory.

    Each turn becomes one 10-field ``SPEAKER`` record:
    ``SPEAKER <file> <chnl> <tbeg> <tdur> <NA> <NA> <speaker> <NA> <NA>``.

    Args:
        annotations: Object exposing ``itertracks(yield_label=True)`` (see
            ``format_diarization_results``).
        audio_filename: Base name (no extension) of the audio file; used as
            the RTTM file ID and as the output file name.

    Returns:
        Path of the written ``.rttm`` file as a string.
    """
    # The directory is intentionally not removed here: the returned path is
    # handed to the download component, which needs the file to still exist.
    temp_dir = tempfile.mkdtemp()
    rttm_path = Path(temp_dir) / f"{audio_filename}.rttm"

    # RTTM fields are whitespace-delimited, so whitespace inside the file ID
    # would shift every following column; replace it with underscores.
    file_id = "_".join(str(audio_filename).split()) or "audio"

    with open(rttm_path, "w", encoding="utf-8") as f:
        for turn, _, speaker in annotations.itertracks(yield_label=True):
            duration = turn.end - turn.start
            f.write(
                f"SPEAKER {file_id} 1 {turn.start:.3f} {duration:.3f} "
                f"<NA> <NA> {speaker} <NA> <NA>\n"
            )

    return str(rttm_path)


@spaces.GPU(duration=120)
def diarize_audio(audio_file, model_choice) -> Tuple[str, Optional[str]]:
    """Run speaker diarization on an uploaded audio file.

    Args:
        audio_file: Path to the audio file, or ``None`` if nothing was
            uploaded.
        model_choice: One of the labels in ``MODEL_IDS``.

    Returns:
        ``(text, rttm_path)``. On success ``text`` is the formatted results
        table and ``rttm_path`` points to the generated RTTM file. On any
        failure ``text`` is an error/warning message and ``rttm_path`` is
        ``None``; errors are reported in the UI rather than raised.
    """
    if not DIARIZEN_AVAILABLE:
        return "❌ Error: DiariZen not installed. Please install from https://github.com/BUTSpeechFIT/DiariZen", None

    if audio_file is None:
        return "⚠️ Please upload an audio file", None

    model_id = MODEL_IDS.get(model_choice)
    if model_id is None:
        return f"❌ Error: unknown model choice: {model_choice!r}", None

    try:
        pipeline = load_diarizen_pipeline(model_id)

        audio_name = Path(audio_file).stem
        print(f"🎤 Processing audio: {audio_file}")

        # Run diarization
        annotations = pipeline(audio_file)
        print("✅ Diarization complete")

        results_text = format_diarization_results(annotations)
        rttm_path = save_rttm(annotations, audio_name)
        return results_text, rttm_path
    except Exception as e:
        # Top-level UI boundary: log the traceback and show the message to
        # the user instead of crashing the request.
        error_msg = f"❌ Error during diarization:\n{e}"
        print(error_msg)
        traceback.print_exc()
        return error_msg, None


# Build Gradio Interface
with gr.Blocks(title="DiariZen Speaker Diarization") as demo:
    gr.Markdown(HEADER_MD)

    if not DIARIZEN_AVAILABLE:
        gr.Markdown(NOT_INSTALLED_MD)

    with gr.Row():
        with gr.Column():
            # Audio input
            audio_input = gr.Audio(
                label="📤 Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"],
            )

            # Model selection
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_IDS),
                value=DEFAULT_MODEL_CHOICE,
                label="🤖 Select Model",
                info="Choose diarization model",
            )

            # Run button
            run_btn = gr.Button("▶️ Run Diarization", variant="primary", size="lg")

        with gr.Column():
            # Results output
            results_output = gr.Textbox(
                label="📊 Diarization Results",
                lines=20,
                max_lines=30,
                show_copy_button=True,
            )

            # RTTM download
            rttm_output = gr.File(
                label="📝 Download RTTM",
                interactive=False,
            )

    # Model information
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown(MODEL_INFO_MD)

    # Footer
    gr.Markdown(FOOTER_MD)

    # Connect button to function
    run_btn.click(
        fn=diarize_audio,
        inputs=[audio_input, model_dropdown],
        outputs=[results_output, rttm_output],
    )


if __name__ == "__main__":
    demo.launch()