Last active
June 3, 2019 07:47
-
-
Save georgepar/09ea1ca5e9933fd52840c663ae41245b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import requests | |
from torch.utils.data import Dataset | |
from silx.io.dictdump import h5todict | |
def download_file(url, fname, chunk_size=1 << 16, timeout=60):
    """Download ``url`` to the local path ``fname`` and return ``fname``.

    The response body is streamed to a temporary ``fname + '.part'`` file and
    only renamed to ``fname`` once the transfer has completed. An interrupted
    download therefore never shows up at ``fname``, which matters because
    callers treat an existing file as an already-downloaded cache hit.

    Args:
        url (str): URL to fetch.
        fname (str): Destination path. Missing parent directories are created.
        chunk_size (int): Bytes per streamed chunk. ``requests`` defaults to a
            1-byte chunk, which makes large downloads extremely slow.
        timeout (float): Connect/read timeout in seconds, passed to
            ``requests.get`` (it bounds each wait, not the total transfer).

    Returns:
        str: ``fname``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    parent = os.path.dirname(fname)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_fname = fname + '.part'
    with requests.get(url, stream=True, timeout=timeout) as resp:
        # Fail loudly instead of saving an HTML error page as the data file.
        resp.raise_for_status()
        try:
            with open(tmp_fname, 'wb') as fd:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    fd.write(chunk)
        except BaseException:
            # Don't leave a truncated file behind; re-raise the original error.
            if os.path.exists(tmp_fname):
                os.remove(tmp_fname)
            raise
    # Atomic on the same filesystem: fname appears only when complete.
    os.replace(tmp_fname, fname)
    return fname
class CMUMosi(Dataset):
    """PyTorch ``Dataset`` skeleton for the CMU-MOSI dataset.

    A loader for CMU-MOSEI should be very similar; override ``LABEL_URLS`` and
    ``LABELS_KEY`` in a subclass. On construction the label file for ``task`` is
    downloaded once into ``download_path`` and parsed. Audio is not loaded yet.
    The raw wavs still need to be downloaded and mapped to ``self.file_ids``
    in order to fill ``self.data``, so until then the dataset has length 0.
    No alignment between modalities or label intervals is performed.

    Args:
        wavs_path (str): Path to the raw wav files (stored, not read yet).
        download_path (str): Directory where the label ``.csd`` file is
            cached. Created if it does not exist.
        task (str): Which label set to load: ``'sentiment'`` or ``'emotion'``.

    Raises:
        KeyError: If ``task`` is not a key of ``LABEL_URLS``.
    """

    # NOTE(review): these URLs point at the CMU-MOSEI label files although the
    # class targets CMU-MOSI, while ``LABELS_KEY`` below is the MOSI-style group
    # name. Confirm which dataset's files/group names are actually intended.
    LABEL_URLS = {
        "sentiment": "http://immortal.multicomp.cs.cmu.edu/CMU-MOSEI/labels/CMU_MOSEI_LabelsSentiment.csd",
        "emotion": "http://immortal.multicomp.cs.cmu.edu/CMU-MOSEI/labels/CMU_MOSEI_LabelsEmotions.csd",
    }
    # Top-level HDF5 group inside the .csd file that holds the per-video labels.
    LABELS_KEY = 'Opinion Segment Labels'

    def __init__(self, wavs_path, download_path='./cmu_mosei', task='sentiment'):
        self.wavs_path = wavs_path
        self.task = task
        # Validate the task before touching the filesystem or the network.
        try:
            url = self.LABEL_URLS[task]
        except KeyError:
            raise KeyError('Unknown task {!r}; expected one of {}'.format(
                task, sorted(self.LABEL_URLS))) from None
        os.makedirs(download_path, exist_ok=True)
        self.label_file = os.path.join(download_path, url.split('/')[-1])
        if not os.path.isfile(self.label_file):
            self.label_file = download_file(url, self.label_file)
        label_dict = h5todict(self.label_file)[self.LABELS_KEY]['data']
        # Labels are dicts in the form {'features': numpy.array, 'intervals': numpy array}
        # features contains sentiment annotations while intervals
        # contains the start and end time of the respective label in the video
        # keys()/values() iterate in the same order, so ids and labels stay
        # paired; unlike zip(*items()) this also works for an empty dict.
        self.file_ids = list(label_dict.keys())
        self.labels = list(label_dict.values())
        # TODO: read the wavs for self.file_ids and extract features into
        # self.data, keeping it index-aligned with self.labels.
        self.data = []

    def __len__(self):
        """Return the number of loaded samples (0 until ``self.data`` is filled)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment