A dataloader (IterableDataset) that creates random training mixtures from stems on the fly. Useful for source-separation datasets like MUSDB18HQ, where each track folder holds one WAV per stem plus a mixture.wav (skipped here); stems sharing a filename are grouped across tracks and combined at random.
import random
import time
import threading
import queue
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import soundfile as sf
import torch
from torch.utils.data import IterableDataset, get_worker_info
from einops import rearrange
@dataclass
class WavMetadata:
    """Path and length (in samples) of a single stem WAV file."""
    path: Path
    num_samples: int
@dataclass
class AudioChunk:
    """A fixed-length window into one stem file, starting at start_sample."""
    file: WavMetadata
    start_sample: int
@dataclass
class MixtureMetadata:
    """The chunk to read for each stem type when assembling one mixture."""
    chunks: Dict[str, AudioChunk]
@dataclass
class Mixture:
    """Loaded audio for each stem type, keyed by stem name (numpy arrays from soundfile)."""
    audios: Dict[str, np.ndarray]
class TrainDataset(IterableDataset):
def __init__(
self,
root_dir: str,
duration_seconds: int,
targets: List[str],
queue_size: int = 100,
):
        super().__init__()
        self.root_dir = Path(root_dir)
        # MUSDB18HQ is 44.1 kHz, so the chunk length is expressed in samples at that rate.
        self.chunk_length = duration_seconds * 44100
        self.targets = targets
        self.queue_size = queue_size
        # Mixtures are produced by a background thread and consumed in __iter__.
        self.mixture_queue = queue.Queue(maxsize=queue_size)
        self.stop_generating = threading.Event()
        self.generate_mixture_thread = None
        # Maps stem name (e.g. "vocals") to all WAV files with that name across tracks.
        self.wav_groups = self.group_wavs_by_name()
        self.wav_types = list(self.wav_groups.keys())
def group_wavs_by_name(self) -> Dict[str, List[WavMetadata]]:
wav_groups = defaultdict(list)
root_path = self.root_dir
for wav in root_path.rglob("*.wav"):
if wav.name == "mixture.wav":
continue
base_name = wav.stem
wav_metadata = WavMetadata(
path=wav,
num_samples=sf.info(wav).frames
)
wav_groups[base_name].append(wav_metadata)
return wav_groups
    def create_mixture(self, metadata: MixtureMetadata) -> Optional[Mixture]:
        audios = {}
        for chunk_name, chunk in metadata.chunks.items():
            try:
                # always_2d keeps a (time, channels) shape even for mono files.
                audio, _ = sf.read(chunk.file.path,
                                   start=chunk.start_sample,
                                   stop=chunk.start_sample + self.chunk_length,
                                   always_2d=True)
                # Reject the whole mixture if any stem chunk is essentially silent.
                if not np.any(np.abs(audio) > 0.001):
                    return None
                audios[chunk_name] = audio
            except Exception as e:
                print(f"Warning: Error reading audio chunk: {e}")
                return None
        return Mixture(audios=audios)
    def generate_mixtures(self):
        # Background producer: keep assembling random mixtures until asked to stop.
        while not self.stop_generating.is_set():
            try:
                # For every stem type, pick a random file and a random valid start offset.
                mixture_metadata = MixtureMetadata(chunks={
                    wav_type: AudioChunk(
                        file=(wav := random.choice(self.wav_groups[wav_type])),
                        start_sample=random.randint(0, max(0, wav.num_samples - self.chunk_length))
                    )
                    for wav_type in self.wav_types
                })
                mixture = self.create_mixture(mixture_metadata)
                if mixture is None:
                    continue
                try:
                    self.mixture_queue.put(mixture, timeout=1.0)
                except queue.Full:
                    # Consumer is not keeping up; drop this mixture and keep going.
                    continue
            except Exception as e:
                print(f"Warning: Error generating mixture: {e}")
                time.sleep(0.1)
def start_generating_mixtures(self):
if self.generate_mixture_thread is None or not self.generate_mixture_thread.is_alive():
self.stop_generating.clear()
self.generate_mixture_thread = threading.Thread(target=self.generate_mixtures, daemon=True)
self.generate_mixture_thread.start()
def stop_generating_mixtures(self):
if self.generate_mixture_thread is not None:
self.stop_generating.set()
self.generate_mixture_thread.join(timeout=5.0)
    def get_mixture_and_targets_tensor(self, mixture: Mixture) -> Tuple[torch.Tensor, torch.Tensor]:
        audio_tensors = {}
        for stem_type, audio in mixture.audios.items():
            audio_tensor = torch.from_numpy(audio).float()
            # soundfile returns (time, channels); rearrange to (channels, time).
            audio_tensor = rearrange(audio_tensor, 't c -> c t')
            audio_tensors[stem_type] = audio_tensor
        # The mixture is simply the sum of all stems.
        mixture_audio = torch.stack(list(audio_tensors.values())).sum(dim=0)
        target_audios = []
        for target_name in self.targets:
            target_audios.append(audio_tensors[target_name])
        # mixture_audio: (channels, time); targets_tensor: (num_targets, channels, time).
        targets_tensor = torch.stack(target_audios)
        return mixture_audio, targets_tensor
    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is not None:
            # Give each DataLoader worker its own seed so workers yield different mixtures.
            worker_seed = (worker_info.id + int(time.time() * 1_000_000)) % 2 ** 32
            random.seed(worker_seed)
            np.random.seed(worker_seed)
self.start_generating_mixtures()
while True:
try:
mixture = self.mixture_queue.get(timeout=5.0)
mixture_audio, targets_tensor = self.get_mixture_and_targets_tensor(mixture)
yield mixture_audio, targets_tensor
except queue.Empty:
print("Warning: Queue is empty, waiting for mixtures...")
continue
except Exception as e:
print(f"Warning: Skipping sample due to error: {e}")
continue
def __del__(self):
self.stop_generating_mixtures()
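
# Example usage: a minimal sketch (not part of the original gist) that wraps the
# dataset in a standard torch DataLoader. The root path, stem names, batch size,
# and worker count are placeholder assumptions; adjust them for your own data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = TrainDataset(
        root_dir="/path/to/musdb18hq/train",  # hypothetical location of the stem folders
        duration_seconds=6,
        targets=["vocals", "drums", "bass", "other"],
    )
    loader = DataLoader(dataset, batch_size=4, num_workers=2)
    mixture, targets = next(iter(loader))
    # mixture: (batch, channels, time); targets: (batch, num_targets, channels, time)
    print(mixture.shape, targets.shape)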