Last active
July 5, 2025 19:26
-
-
Save polson/638fc1d9a4783d9a1637fc4bd21e13d5 to your computer and use it in GitHub Desktop.
A data loader that creates random training mixtures from individual stems. Useful for datasets such as MUSDB18-HQ.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import queue
import random
import sys
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import soundfile as sf
import torch
from einops import rearrange
from torch.utils.data import IterableDataset, get_worker_info
@dataclass
class WavMetadata:
    """A stem file on disk together with its length in sample frames."""

    path: Path  # location of the .wav file
    num_samples: int  # total frame count of the file
@dataclass
class AudioChunk:
    """A window inside one stem file, identified by the frame it starts at."""

    file: WavMetadata  # stem file the window is cut from
    start_sample: int  # zero-based frame index where the window begins
@dataclass
class MixtureMetadata:
    """Recipe for one training example: which window to read for each stem."""

    chunks: Dict[str, AudioChunk]  # stem type (wav file stem) -> window to read
@dataclass
class Mixture:
    """Decoded audio for one training example, keyed by stem name."""

    # np.ndarray, not torch.Tensor: TrainDataset.create_mixture stores the raw
    # soundfile.read output here, and it is converted with torch.from_numpy
    # only when the training tensors are built.
    audios: Dict[str, np.ndarray]
class TrainDataset(IterableDataset):
    """Endless stream of randomly mixed source-separation training examples.

    Every ``*.wav`` file under ``root_dir`` (except ``mixture.wav``) is grouped
    by its file stem (e.g. ``vocals``, ``drums``).  A background thread keeps
    picking one random file per stem type, reads a random window of
    ``duration_seconds`` from each, rejects the draw if any window is
    (near-)silent, and pushes the result onto a bounded queue.  Because each
    stem's file is chosen independently, the stems of one example generally
    come from different songs.

    Iterating yields ``(mixture_audio, targets_tensor)`` forever, where
    ``mixture_audio`` is the sum of all stems, shaped ``(channels, time)``, and
    ``targets_tensor`` stacks the stems named in ``targets``, shaped
    ``(len(targets), channels, time)``.
    """

    def __init__(
        self,
        root_dir: str,
        duration_seconds: int,
        targets: List[str],
        queue_size: int = 100,
        sample_rate: int = 44100,
        silence_threshold: float = 0.001,
    ):
        """
        Args:
            root_dir: Directory searched recursively for stem ``.wav`` files.
            duration_seconds: Length of each training window in seconds.
            targets: Stem names (wav file stems) stacked into the target
                tensor, in this order.
            queue_size: Maximum number of pre-generated mixtures buffered.
            sample_rate: Used to convert ``duration_seconds`` into frames.
                Files are assumed to already be at this rate; no resampling
                is done.
            silence_threshold: A draw is rejected when every sample of any
                stem window has absolute value <= this threshold.

        Raises:
            ValueError: If a requested target has no matching stem files.
        """
        super().__init__()
        self.root_dir = Path(root_dir)
        self.sample_rate = sample_rate
        # int() so a fractional duration still gives an integer frame count.
        self.chunk_length = int(duration_seconds * sample_rate)
        self.targets = targets
        self.queue_size = queue_size
        self.silence_threshold = silence_threshold
        self._init_threading_state()
        self.wav_groups = self.group_wavs_by_name()
        self.wav_types = list(self.wav_groups.keys())

        # Fail fast: an unknown target would otherwise make every sample
        # raise KeyError, and __iter__ would print warnings forever.
        missing = [name for name in self.targets if name not in self.wav_types]
        if missing:
            raise ValueError(
                f"Targets {missing} have no stem files under {self.root_dir}; "
                f"available stems: {sorted(self.wav_types)}"
            )

    def _init_threading_state(self) -> None:
        """(Re)create the queue, stop flag and thread handle of the producer."""
        self.mixture_queue = queue.Queue(maxsize=self.queue_size)
        self.stop_generating = threading.Event()
        self.generate_mixture_thread = None

    def __getstate__(self):
        # Queue, Event and Thread hold locks that cannot be pickled (DataLoader
        # workers started with "spawn" pickle the dataset), so drop them here
        # and rebuild fresh ones in __setstate__.
        state = self.__dict__.copy()
        for key in ("mixture_queue", "stop_generating", "generate_mixture_thread"):
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._init_threading_state()

    def group_wavs_by_name(self) -> Dict[str, List[WavMetadata]]:
        """Group every stem ``.wav`` under ``root_dir`` by file stem.

        ``mixture.wav`` files are skipped.  Paths are sorted so the grouping
        (and therefore seeded sampling) does not depend on filesystem order.
        """
        wav_groups = defaultdict(list)
        for wav in sorted(self.root_dir.rglob("*.wav")):
            if wav.name == "mixture.wav":
                continue
            wav_groups[wav.stem].append(
                WavMetadata(path=wav, num_samples=sf.info(wav).frames)
            )
        return wav_groups

    def create_mixture(self, metadata: MixtureMetadata) -> Optional[Mixture]:
        """Read every chunk in ``metadata`` from disk.

        Returns:
            A ``Mixture`` mapping stem name to a ``(chunk_length, channels)``
            array, or ``None`` if any chunk is (near-)silent or cannot be read.
        """
        audios = {}
        for chunk_name, chunk in metadata.chunks.items():
            try:
                audio, _ = sf.read(
                    chunk.file.path,
                    frames=self.chunk_length,
                    start=chunk.start_sample,
                    # Zero-pad files shorter than chunk_length so all stems
                    # (and all batch items) have the same length.  (With
                    # start/stop soundfile clamps to the file end instead.)
                    fill_value=0.0,
                    # Keep mono files 2-D so the 't c -> c t' rearrange applies.
                    always_2d=True,
                )
            except Exception as e:
                print(f"Warning: Error reading audio chunk: {e}")
                return None
            if not np.any(np.abs(audio) > self.silence_threshold):
                return None
            audios[chunk_name] = audio
        return Mixture(audios=audios)

    def _random_mixture_metadata(self) -> MixtureMetadata:
        """Pick one random file per stem type and a random window start in it."""
        chunks = {}
        for wav_type in self.wav_types:
            wav = random.choice(self.wav_groups[wav_type])
            # randint is inclusive, so the largest start still leaves a full
            # chunk_length window; shorter files always start at frame 0.
            start = random.randint(0, max(0, wav.num_samples - self.chunk_length))
            chunks[wav_type] = AudioChunk(file=wav, start_sample=start)
        return MixtureMetadata(chunks=chunks)

    def _put_until_stopped(self, mixture: Mixture) -> None:
        """Enqueue ``mixture``, retrying while the queue is full until stopped."""
        while not self.stop_generating.is_set():
            try:
                # The timeout lets the loop re-check the stop flag periodically.
                self.mixture_queue.put(mixture, timeout=1.0)
                return
            except queue.Full:
                continue

    def generate_mixtures(self):
        """Producer loop run on the background thread until a stop is requested."""
        while not self.stop_generating.is_set():
            try:
                mixture = self.create_mixture(self._random_mixture_metadata())
                if mixture is not None:
                    # Keep the already-decoded mixture rather than discarding
                    # it when the consumer is slow.
                    self._put_until_stopped(mixture)
            except Exception as e:
                print(f"Warning: Error generating mixture: {e}")
                time.sleep(0.1)

    def start_generating_mixtures(self):
        """Start the daemon producer thread unless one is already running."""
        if self.generate_mixture_thread is None or not self.generate_mixture_thread.is_alive():
            self.stop_generating.clear()
            self.generate_mixture_thread = threading.Thread(target=self.generate_mixtures, daemon=True)
            self.generate_mixture_thread.start()

    def stop_generating_mixtures(self):
        """Signal the producer thread to stop and wait up to 5 s for it."""
        if self.generate_mixture_thread is not None:
            self.stop_generating.set()
            self.generate_mixture_thread.join(timeout=5.0)

    def get_mixture_and_targets_tensor(self, mixture: Mixture) -> Tuple[
            torch.Tensor, torch.Tensor]:
        """Convert a ``Mixture`` into ``(mixture_audio, targets_tensor)``.

        Each stem array is cast to float32 and transposed from
        ``(time, channels)`` to ``(channels, time)``.  ``mixture_audio`` is the
        element-wise sum over all stems; ``targets_tensor`` stacks the stems in
        ``self.targets`` order.
        """
        audio_tensors = {
            stem_name: rearrange(torch.from_numpy(audio).float(), 't c -> c t')
            for stem_name, audio in mixture.audios.items()
        }
        mixture_audio = torch.stack(list(audio_tensors.values())).sum(dim=0)
        targets_tensor = torch.stack([audio_tensors[name] for name in self.targets])
        return mixture_audio, targets_tensor

    def __iter__(self):
        """Yield ``(mixture_audio, targets_tensor)`` pairs indefinitely."""
        worker_info = get_worker_info()
        if worker_info is not None:
            # Give each DataLoader worker its own RNG stream.  The sum is taken
            # modulo 2**32 as a whole so the seed stays within numpy's valid
            # [0, 2**32 - 1] range.
            worker_seed = (worker_info.id + int(time.time() * 1000000)) % 2 ** 32
            random.seed(worker_seed)
            np.random.seed(worker_seed)
        self.start_generating_mixtures()
        while True:
            try:
                mixture = self.mixture_queue.get(timeout=5.0)
            except queue.Empty:
                print("Warning: Queue is empty, waiting for mixtures...")
                continue
            try:
                sample = self.get_mixture_and_targets_tensor(mixture)
            except Exception as e:
                print(f"Warning: Skipping sample due to error: {e}")
                continue
            # yield sits outside the try so exceptions thrown into the
            # generator by the consumer are not swallowed as "bad samples".
            yield sample

    def __del__(self):
        # __init__ may have failed before the thread attributes were created.
        if getattr(self, "generate_mixture_thread", None) is not None:
            self.stop_generating_mixtures()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment