#!/usr/bin/env python3

# audioprep (lite) — on-screen teaching only, no doc markdown in file

import argparse
import random
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torchaudio as ta
import torchaudio.transforms as T
from torch.nn import functional as F

from rich.console import Console
from rich.theme import Theme
from rich.table import Table
from rich.panel import Panel
from rich.progress import (
    Progress,
    BarColumn,
    MofNCompleteColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)

console = Console(theme=Theme({
    "ok": "bold green",
    "warn": "bold yellow",
    "err": "bold red",
    "muted": "dim",
    "title": "bold cyan",
    "info": "white",
}))


def list_audio(root: Path) -> List[Path]:
    exts = {".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aac"}
    return [p for p in root.rglob("*") if p.suffix.lower() in exts and p.is_file()]
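# Note: whether .mp3/.m4a/.aac actually decode depends on the torchaudio I/O
# backend installed at runtime (e.g. ffmpeg); .wav and .flac are the safe bets.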


def class_of(root: Path, path: Path) -> str:
    # The first path component under the input root is treated as the label.
    try:
        return path.relative_to(root).parts[0]
    except ValueError:  # path is not located under root
        return "unknown"


def pad_trim(wav: torch.Tensor, target_len: int) -> torch.Tensor:
    n = wav.size(1)
    if n > target_len: return wav[:, :target_len]
    if n < target_len: return F.pad(wav, (0, target_len - n))
    return wav
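# e.g. sr=16000, seconds=2.0 -> target_len=32000: a 1.5 s clip (24000 samples)
# gains 8000 trailing zeros; a 3 s clip is cut to its first 32000 samples.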


def time_shift(wav: torch.Tensor, sr: int, max_ms: int) -> torch.Tensor:
    if max_ms <= 0: return wav
    shift = int(random.uniform(-max_ms, max_ms) * sr / 1000.0)
    return torch.roll(wav, shifts=shift, dims=1) if shift else wav
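# torch.roll is circular: samples pushed past one end wrap around to the other.
# e.g. max_ms=100 at sr=16000 draws a shift of up to ±1600 samples either way.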


def choose_bg(bg_files: List[Path]) -> Optional[Path]:
    return random.choice(bg_files) if bg_files else None


def resample_if_needed(wav: torch.Tensor, sr0: int, sr: int) -> torch.Tensor:
    if sr0 == sr: return wav
    return ta.functional.resample(wav, sr0, sr)


def mix_bg(wav: torch.Tensor, bg: torch.Tensor, snr_db: float) -> torch.Tensor:
    s = wav.std().clamp_min(1e-6)
    n = bg.std().clamp_min(1e-6)
    alpha = s / (10 ** (snr_db / 20) * n)
    return torch.clamp(wav + alpha * bg, -1.0, 1.0)
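# The scaling solves 20*log10(s / (alpha*n)) == snr_db for alpha, with per-clip
# std as a rough amplitude proxy. Worked example (illustrative values):
#   s=0.1, n=0.2, snr_db=10 -> alpha = 0.1 / (10**0.5 * 0.2) ≈ 0.158,
# i.e. the higher the requested SNR, the quieter the mixed-in background.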


def load_mono(path: Path, sr: int) -> Tuple[torch.Tensor, int]:
    wav, sr0 = ta.load(path)  # float tensor, normalized to [-1, 1]
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)  # downmix to mono by averaging channels
    wav = resample_if_needed(wav, sr0, sr)
    return wav, sr


def human_time(sec: float) -> str:
    m, s = divmod(int(sec), 60)
    h, m = divmod(m, 60)
    if h: return f"{h}h {m}m {s}s"
    if m: return f"{m}m {s}s"
    return f"{s}s"
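# e.g. human_time(3725) -> "1h 2m 5s"; human_time(59.9) -> "59s" (int truncates).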


class Processor:
    def __init__(
        self,
        input_dir: Path,
        output_dir: Path,
        sr: int,
        seconds: float,
        feature: str,
        n_fft: int,
        hop: int,
        mel_bins: int,
        awgn: float,
        bg_dir: Optional[Path],
        snr: Optional[float],
        time_shift_ms: int,
        teach: bool,
    ):
        self.input = input_dir
        self.output = output_dir
        self.sr = sr
        self.seconds = seconds
        self.feature = feature
        self.n_fft = n_fft
        self.hop = hop
        self.mel_bins = mel_bins
        self.awgn = awgn
        self.bg_dir = bg_dir
        self.snr = snr
        self.time_shift_ms = time_shift_ms
        self.teach = teach

        # power=None keeps the complex STFT; magnitude is taken in process_one.
        self.spec = T.Spectrogram(n_fft=n_fft, hop_length=hop, power=None, normalized=True)
        self.to_db_mag = T.AmplitudeToDB(stype="magnitude")  # 20*log10, for |STFT|
        self.to_db_power = T.AmplitudeToDB(stype="power")  # 10*log10, for power mels
        self.mel = T.MelSpectrogram(sample_rate=sr, n_fft=n_fft, hop_length=hop, n_mels=mel_bins)
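        # Rough output shapes, assuming the CLI defaults (sr=16000, seconds=2.0,
        # n_fft=1024, hop=512) and torchaudio's center=True framing:
        #   spectrogram: (1, n_fft // 2 + 1, frames) = (1, 513, 63)
        #   mel:         (1, mel_bins, frames)       = (1, 64, 63)
        # where frames = target_len // hop + 1 = 32000 // 512 + 1 = 63.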

        self.bg_files = list_audio(bg_dir) if bg_dir else []
        self.counts: Dict[str, int] = {}
        self.total_out_sec = 0.0

    def _explain_plan(self):
        if not self.teach: return
        txt = (
            "[bold]Plan[/bold]\n"
            f"- Load (normalized to [-1,1]) → resample to [ok]{self.sr} Hz[/ok]\n"
            f"- Pad/Trim to [ok]{self.seconds:.2f}s[/ok]\n"
            f"- Optional augments: AWGN={self.awgn}, BG={'on' if self.bg_dir else 'off'}"
            f"{f' @ {self.snr} dB' if self.bg_dir else ''}, TimeShift=±{self.time_shift_ms}ms\n"
            f"- Export feature: [ok]{self.feature}[/ok] (plus WAV)\n"
        )
        console.print(Panel(txt, title="[ Teaching ]", title_align="left", style="muted"))

    def _feature_paths(self, in_path: Path) -> Tuple[Path, Optional[Path]]:
        rel = in_path.relative_to(self.input)
        out_wav = (self.output / "audio" / rel).with_suffix(".wav")
        out_feat = None
        if self.feature == "spectrogram":
            out_feat = (self.output / "features_spectrogram" / rel).with_suffix(".pt")
        elif self.feature == "mel":
            out_feat = (self.output / "features_mel" / rel).with_suffix(".pt")
        return out_wav, out_feat
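        # Output mirrors the input tree, e.g. (hypothetical) <input>/dog/bark.wav
        # -> <output>/audio/dog/bark.wav plus <output>/features_mel/dog/bark.pt
        # when feature == "mel"; for feature == "waveform" only the WAV is written.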

    def _teach_step(self, title: str, why: str):
        if not self.teach: return
        console.print(f"[title]{title}[/title] — {why}")

    def _save_audio(self, wav: torch.Tensor, path: Path):
        path.parent.mkdir(parents=True, exist_ok=True)
        ta.save(str(path), wav, self.sr)

    def _save_tensor(self, t: torch.Tensor, path: Path):
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(t.cpu(), path)

    def process_one(self, path: Path):
        label = class_of(self.input, path)

        # Load
        self._teach_step("Load", "torchaudio.load → float [-1,1]; convert to mono; resample if needed")
        wav, _ = load_mono(path, self.sr)

        # Pad/Trim
        self._teach_step("Pad/Trim", f"force fixed length {self.seconds:.2f}s so the model sees consistent shapes")
        target_len = int(self.sr * self.seconds)
        wav = pad_trim(wav, target_len)

        # Time shift
        if self.time_shift_ms:
            self._teach_step("Time Shift", f"small circular shift (±{self.time_shift_ms}ms) to reduce overfitting")
            wav = time_shift(wav, self.sr, self.time_shift_ms)

        # Background mix
        if self.bg_dir and self.snr is not None and self.bg_files:
            self._teach_step("BG Mix", f"mix a random background at {self.snr} dB SNR for robustness")
            bg_path = choose_bg(self.bg_files)
            if bg_path:
                bg, _ = load_mono(bg_path, self.sr)
                bg = pad_trim(bg, target_len)
                wav = mix_bg(wav, bg, self.snr)

        # AWGN
        if self.awgn > 0:
            self._teach_step("AWGN", f"add small Gaussian noise (σ={self.awgn}) to encourage generalization")
            wav = torch.clamp(wav + torch.randn_like(wav) * self.awgn, -1.0, 1.0)

        # Save outputs
        out_wav, out_feat = self._feature_paths(path)
        self._teach_step("Save WAV", f"write standardized audio → {out_wav}")
        self._save_audio(wav, out_wav)

        if out_feat is not None:
            if self.feature == "spectrogram":
                self._teach_step("Spectrogram", "STFT magnitude → dB (human-like scale)")
                spec = self.spec(wav).abs()  # complex STFT -> magnitude
                spec_db = self.to_db_mag(spec)  # 20*log10 for amplitude values
                self._save_tensor(spec_db, out_feat)
            elif self.feature == "mel":
                self._teach_step("Mel Spectrogram", f"{self.mel_bins} mel bands → dB")
                mel = self.mel(wav)  # power mel spectrogram (power=2.0 default)
                mel_db = self.to_db_power(mel)
                self._save_tensor(mel_db, out_feat)

        self.counts[label] = self.counts.get(label, 0) + 1
        self.total_out_sec += self.seconds

    def run(self):
        files = list_audio(self.input)
        if not files:
            console.print(f"[err]No audio files under {self.input}")
            sys.exit(2)

        classes = sorted({class_of(self.input, p) for p in files})
        t = Table(title="Discovered Classes", title_style="title")
        t.add_column("#", justify="right"); t.add_column("Class")
        for i, c in enumerate(classes, 1): t.add_row(str(i), c)
        console.print(t)

        self._explain_plan()

        with Progress(
            "[progress.percentage]{task.percentage:>3.0f}%",
            BarColumn(), MofNCompleteColumn(), TimeElapsedColumn(), TimeRemainingColumn(),
            console=console, transient=False,
        ) as progress:
            task = progress.add_task("[ok]Processing", total=len(files))
            for f in files:
                try:
                    self.process_one(f)
                except Exception as e:
                    console.print(f"[warn]Skipping {f.name}: {e}")
                progress.advance(task)

        # Summary on screen only
        tt = Table(title="Class Counts", title_style="title")
        tt.add_column("Class"); tt.add_column("Count", justify="right")
        for c in sorted(self.counts): tt.add_row(c, str(self.counts[c]))
        console.print(tt)
        console.print(Panel.fit(
            f"[ok]Done[/ok] • {sum(self.counts.values())} files • standardized {human_time(self.total_out_sec)} of audio\nOutput: {self.output}",
            style="ok",
        ))


def main(argv=None):
    ap = argparse.ArgumentParser(
        prog="audioprep",
        description="Build a standardized audio dataset with on-screen teaching.",
    )
    ap.add_argument("-i", "--input", type=Path, help="Input root with class subfolders")
    ap.add_argument("-o", "--output", type=Path, help="Output directory")
    ap.add_argument("--sr", type=int, default=16000, help="Target sample rate")
    ap.add_argument("--seconds", type=float, default=2.0, help="Fixed clip length (sec)")
    ap.add_argument("--feature", choices=["waveform", "spectrogram", "mel"], default="mel", help="Feature to export (besides WAV)")
    ap.add_argument("--n-fft", type=int, default=1024, help="n_fft for STFT/Mel")
    ap.add_argument("--hop", type=int, default=512, help="hop length")
    ap.add_argument("--mel-bins", type=int, default=64, help="mel bins for mel feature")
    ap.add_argument("--awgn", type=float, default=0.0, help="Gaussian noise stddev")
    ap.add_argument("--bg-dir", type=Path, help="Background audio dir (for mixing)")
    ap.add_argument("--snr", type=float, help="SNR dB when mixing background (requires --bg-dir)")
    ap.add_argument("--time-shift-ms", type=int, default=0, help="± time shift (ms)")
    ap.add_argument("--no-teach", action="store_true", help="Disable on-screen explanations")
    args = ap.parse_args(argv)

    # Interactive fill if -i/-o missing
    if not args.input or not args.output:
        from rich.prompt import Prompt, IntPrompt, FloatPrompt
        console.print(Panel.fit("Interactive mode — press Enter for defaults", style="title"))
        if not args.input: args.input = Path(Prompt.ask("Input dir", default=str(Path.cwd())))
        if not args.output: args.output = Path(Prompt.ask("Output dir", default=str(Path.cwd() / "audioprep_output")))
        args.sr = IntPrompt.ask("Sample rate", default=args.sr)
        args.seconds = FloatPrompt.ask("Clip length (sec)", default=args.seconds)
        args.feature = Prompt.ask("Feature", choices=["waveform", "spectrogram", "mel"], default=args.feature)
        if args.feature in ("spectrogram", "mel"):
            args.n_fft = IntPrompt.ask("n_fft", default=args.n_fft)
            args.hop = IntPrompt.ask("hop length", default=args.hop)
        if args.feature == "mel":
            args.mel_bins = IntPrompt.ask("mel bins", default=args.mel_bins)
        args.no_teach = False  # interactive mode always keeps teaching on

    inp = args.input.expanduser().resolve()
    out = args.output.expanduser().resolve()
    if not inp.exists():
        console.print(f"[err]Input not found: {inp}"); sys.exit(2)
    if args.bg_dir:
        args.bg_dir = args.bg_dir.expanduser().resolve()
        if not args.bg_dir.exists():
            console.print(f"[warn]BG dir not found, disabling: {args.bg_dir}")
            args.bg_dir, args.snr = None, None

    proc = Processor(
        input_dir=inp, output_dir=out,
        sr=args.sr, seconds=args.seconds,
        feature=args.feature, n_fft=args.n_fft, hop=args.hop, mel_bins=args.mel_bins,
        awgn=args.awgn, bg_dir=args.bg_dir, snr=args.snr, time_shift_ms=args.time_shift_ms,
        teach=not args.no_teach,
    )
    try:
        proc.run()
    except KeyboardInterrupt:
        console.print("\n[warn]Interrupted. Partial output kept.")
        sys.exit(1)


if __name__ == "__main__":
    main()
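# Example invocation (paths are hypothetical; all flags defined above):
#   python audioprep.py -i data -o out --seconds 2.0 --feature mel \
#       --awgn 0.005 --time-shift-ms 100 --bg-dir noise --snr 10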