Spaces:
Running
Running
| # This module handles the data loading and preprocessing for various phoneme transcription datasets. | |
| import torch | |
| import torchaudio | |
| import zipfile | |
| from pathlib import Path | |
# Absolute path of the directory that contains this module
CURRENT_DIR = Path(__file__).parent.absolute()

# Dataset locations, resolved relative to this module's directory
DATA_DIR = CURRENT_DIR.joinpath("data")
TIMIT_PATH = DATA_DIR.joinpath("TIMIT.zip")
| # Abstract data manager class | |
class DataManager:
    """Base interface for dataset access.

    Each method here raises ``NotImplementedError``; a concrete dataset
    manager overrides all three.
    """

    def get_file_list(self, subset: str) -> list[str]:
        """Return the names of the files that belong to ``subset``."""
        raise NotImplementedError()

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load ``filename`` and return its preprocessed audio as a tensor."""
        raise NotImplementedError()

    def get_phonemes(self, filename: str) -> str:
        """Return the phoneme sequence associated with ``filename``."""
        raise NotImplementedError()
| # Implement datasets | |
class TimitDataManager(DataManager):
    """Handles all TIMIT dataset operations.

    Audio (``.WAV``) and phone-label (``.PHN``) members are read straight
    from the zip archive at ``timit_path``. The archive is opened lazily on
    first access and the same handle is reused afterwards; call ``close()``
    to release it.
    """

    # TIMIT to IPA mapping with direct simplifications.
    # An empty string means "emit nothing for this label". Stop closures
    # (e.g. "tcl") carry the stop symbol while the release labels (e.g. "t")
    # map to "", so a closure+release pair yields a single symbol.
    _TIMIT_TO_IPA = {
        # Vowels (simplified)
        "aa": "ɑ",
        "ae": "æ",
        "ah": "ʌ",
        "ao": "ɔ",
        "aw": "aʊ",
        "ay": "aɪ",
        "eh": "ɛ",
        "er": "ɹ",  # Simplified from 'ɝ'
        "ey": "eɪ",
        "ih": "ɪ",
        "ix": "i",  # Simplified from 'ɨ'
        "iy": "i",
        "ow": "oʊ",
        "oy": "ɔɪ",
        "uh": "ʊ",
        "uw": "u",
        "ux": "u",  # Simplified from 'ʉ'
        "ax": "ə",
        "ax-h": "ə",  # Simplified from 'ə̥'
        "axr": "ɹ",  # Simplified from 'ɚ'
        # Consonants
        "b": "",
        "bcl": "b",
        "d": "",
        "dcl": "d",
        "g": "",
        "gcl": "g",
        "p": "",
        "pcl": "p",
        "t": "",
        "tcl": "t",
        "k": "",
        "kcl": "k",
        "dx": "ɾ",
        "q": "ʔ",
        # Fricatives
        "jh": "dʒ",
        "ch": "tʃ",
        "s": "s",
        "sh": "ʃ",
        "z": "z",
        "zh": "ʒ",
        "f": "f",
        "th": "θ",
        "v": "v",
        "dh": "ð",
        "hh": "h",
        "hv": "h",  # Simplified from 'ɦ'
        # Nasals (simplified)
        "m": "m",
        "n": "n",
        "ng": "ŋ",
        "em": "m",  # Simplified from 'm̩'
        "en": "n",  # Simplified from 'n̩'
        "eng": "ŋ",  # Simplified from 'ŋ̍'
        "nx": "ɾ",  # Simplified from 'ɾ̃'
        # Semivowels and Glides
        "l": "l",
        "r": "ɹ",
        "w": "w",
        "wh": "ʍ",
        "y": "j",
        "el": "l",  # Simplified from 'l̩'
        # Special
        "epi": "",  # Remove epenthetic silence
        "h#": "",  # Remove start/end silence
        "pau": "",  # Remove pause
    }

    def __init__(self, timit_path: Path):
        """Remember the archive location and verify that it exists.

        Args:
            timit_path: Path to the TIMIT zip archive.

        Raises:
            FileNotFoundError: If ``timit_path`` does not exist.
        """
        self.timit_path = timit_path
        self._zip_ = None  # cached ZipFile handle, created on first use
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(
                f"TIMIT dataset not found at {self.timit_path.absolute()}. Try running ./scripts/download_data_lfs.sh again."
            )
        else:
            print("TIMIT dataset file exists!")

    @property
    def _zip(self) -> zipfile.ZipFile:
        """Return the archive handle, opening it on first access.

        This must be a property: the other methods use ``self._zip`` as an
        attribute (``self._zip.namelist()``, ``self._zip.open(...)``), which
        fails on a plain bound method.
        """
        if self._zip_ is None:
            self._zip_ = zipfile.ZipFile(self.timit_path, "r")
        return self._zip_

    def close(self) -> None:
        """Close the cached archive handle, if one is open.

        The archive is transparently reopened on the next access.
        """
        if self._zip_ is not None:
            self._zip_.close()
            self._zip_ = None

    def get_file_list(self, subset: str) -> list[str]:
        """Get list of WAV files for given subset.

        A member is selected when its name ends with ``.WAV`` and contains
        ``subset`` anywhere in its path (case-insensitive substring match).

        Args:
            subset: Substring identifying the subset, e.g. ``"train"``.

        Returns:
            Matching archive member names, in archive order.
        """
        wanted = subset.lower()  # hoisted: invariant across members
        files = [
            name
            for name in self._zip.namelist()
            if name.endswith(".WAV") and wanted in name.lower()
        ]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load and preprocess audio file.

        The audio is down-mixed to one channel, resampled to 16 kHz when
        needed, and normalized to zero mean and unit variance.

        Args:
            filename: Archive member name of the audio file.

        Returns:
            The preprocessed waveform with at least two dimensions
            (presumably ``(1, num_samples)`` — assumes ``torchaudio.load``
            returns ``(channels, samples)``; TODO confirm).
        """
        with self._zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)  # type: ignore
        if waveform.shape[0] > 1:
            # Average the channels down to mono, keeping the channel axis.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        # The epsilon guards against division by zero on constant signals.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA.

        Each non-blank line of the ``.PHN`` file is expected to hold three
        whitespace-separated fields, the last being the TIMIT phone label.
        Labels that are unknown or map to an empty string are dropped.

        Args:
            filename: Archive member name of the ``.WAV`` file; the matching
                ``.PHN`` member is derived from it.

        Returns:
            The IPA symbols concatenated without separators.

        Raises:
            ValueError: If a non-blank line does not have exactly three fields.
        """
        phn_file = filename.replace(".WAV", ".PHN")
        with self._zip.open(phn_file) as f:
            text = f.read().decode("utf-8")
        phonemes = []
        for line in text.splitlines():
            if not line.strip():
                continue
            _, _, phone = line.split()
            phone = self._remove_stress_mark(phone)
            # Convert to IPA; unknown labels fall back to "" and are skipped
            ipa = self._TIMIT_TO_IPA.get(phone.lower(), "")
            if ipa:
                phonemes.append(ipa)
        return "".join(phonemes)  # Join without spaces for IPA

    def _remove_stress_mark(self, text: str) -> str:
        """Remove the combining double inverted breve (U+0361) from text.

        Despite the method name, the character removed is the IPA tie bar,
        not a stress mark; the name is kept for backward compatibility.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        # Written as an escape so the invisible combining character cannot be
        # lost or altered by editors and copy/paste.
        return text.replace("\u0361", "")
# Initialize data managers (constructed when this module is imported)
timit_manager = TimitDataManager(timit_path=TIMIT_PATH)