Created
September 15, 2017 07:53
-
-
Save dhpollack/f33ff65fa0264d1a332f36853e6f07a4 to your computer and use it in GitHub Desktop.
A minimal example of a variable-length dataset for PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.utils.data as data\n", | |
"import torchaudio\n", | |
"import librosa\n", | |
"import numpy as np\n", | |
"import random\n", | |
"import os\n", | |
"import glob" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
" class VariableLengthDataset(data.Dataset):\n", | |
" def __init__(self, manifest, snippet_length=24000, get_sequentially=False, ret_np=False, use_librosa=False):\n", | |
" self.manifest = manifest\n", | |
" self.snippet_length = snippet_length\n", | |
" self.get_sequentially = get_sequentially\n", | |
" self.use_librosa = use_librosa\n", | |
" self.ret_np = ret_np\n", | |
" self.acc = 0\n", | |
" self.snippet_counter = 0\n", | |
" self.audio_idx = 0\n", | |
" self.st = 0\n", | |
" self.data = {}\n", | |
" def __getitem__(self, index):\n", | |
" # load audio data from file or cache\n", | |
" if self.snippet_counter == 0:\n", | |
" self.audio_idx = index - self.acc\n", | |
" apath = self.manifest[self.audio_idx]\n", | |
" if apath not in self.data:\n", | |
" if self.use_librosa:\n", | |
" sig, sr = librosa.core.load(apath, sr=None)\n", | |
" sig = torch.from_numpy(sig).unsqueeze(1).float()\n", | |
" else:\n", | |
" sig, sr = torchaudio.load(apath, normalization=True)\n", | |
" self.data[apath] = (sig, sr)\n", | |
" else:\n", | |
" sig, sr = self.data[apath]\n", | |
"\n", | |
" # increase iterations based on length of audio\n", | |
" num_snippets = int(sig.size(0) // self.snippet_length)\n", | |
" self.acc += max(num_snippets-1,0)\n", | |
" else:\n", | |
" apath = self.manifest[self.audio_idx]\n", | |
" sig, sr = self.data[apath]\n", | |
" num_snippets = int(sig.size(0) // self.snippet_length)\n", | |
"\n", | |
" # create snippet\n", | |
" if self.get_sequentially:\n", | |
" self.st += self.snippet_length\n", | |
" else:\n", | |
" self.st = random.randrange(int(sig.size(0)-self.snippet_length))\n", | |
" ret_sig = sig[self.st:(self.st+self.snippet_length)]\n", | |
" if self.ret_np:\n", | |
" ret_sig = ret_sig.numpy()\n", | |
"\n", | |
" # update counter for current audio file\n", | |
" self.snippet_counter += 1\n", | |
"\n", | |
" # label creation\n", | |
" spkr = os.path.dirname(apath).rsplit(\"/\", 1)[-1]\n", | |
"\n", | |
" # check for reset\n", | |
" if self.snippet_counter >= num_snippets:\n", | |
" self.snippet_counter = 0\n", | |
" self.st = 0\n", | |
"\n", | |
" return ret_sig, spkr\n", | |
"\n", | |
" def __len__(self):\n", | |
" return len(self.data) + self.acc\n", | |
"\n", | |
" def reset_acc(self):\n", | |
" self.acc = 0\n", | |
"\n", | |
" datadir = \"pcsnpny-20150204-mkj\"\n", | |
" audio_manifest = [a for a in glob.glob(datadir+\"/**/*.wav\", recursive=True)]\n", | |
" def run_dataset():\n", | |
" for epoch in range(1):\n", | |
" print(epoch)\n", | |
" all_data = [(x, label) for x, label in ds]\n", | |
" ds.reset_acc()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0\n", | |
"CPU times: user 19.5 ms, sys: 3.57 ms, total: 23 ms\n", | |
"Wall time: 16.2 ms\n", | |
"0\n", | |
"CPU times: user 72.1 ms, sys: 22 µs, total: 72.1 ms\n", | |
"Wall time: 36 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=False)\n", | |
"%time run_dataset()\n", | |
"ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=True)\n", | |
"%time run_dataset()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment