Last active
December 11, 2018 13:40
-
-
Save pedrohbtp/4ac8b6e470d86c42d73875189de21a9e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
class ExampleDataset(Dataset):
    """Example map-style dataset backed by a CSV file loaded with pandas.

    Each sample is one row of the CSV, addressed by its integer position.
    """

    def __init__(self, csv_file):
        """Load the whole CSV file into memory.

        Args:
            csv_file (string): Path to the csv file containing data.
        """
        self.data_frame = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return the row at integer position ``idx`` as a NumPy array.

        Args:
            idx: Integer position of the row, or a tensor holding it.

        Returns:
            numpy.ndarray: The values of the selected row.
        """
        # Samplers may hand over the index as a tensor; convert it to a
        # plain Python value before positional indexing.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # ``data_frame[idx]`` would look up a *column* labelled ``idx``;
        # ``.iloc`` selects by row position, which is what a Dataset needs.
        # The row is returned as an array because the DataLoader's default
        # collate function cannot batch pandas Series objects.
        return self.data_frame.iloc[idx].to_numpy()
# Instantiate the dataset from the CSV file.
example_dataset = ExampleDataset('my_data_file.csv')

# batch_size: number of samples returned per iteration
# shuffle: reshuffle the data so it is not always read in the same order
# num_workers: number of worker processes used to load the data in parallel
# NOTE(review): with num_workers > 0 on platforms that spawn worker
# processes (e.g. Windows), this script body likely needs to live under an
# ``if __name__ == "__main__":`` guard -- confirm for the target platform.
example_data_loader = DataLoader(
    example_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

# Loop over the data 4 samples at a time.
for batch_index, batch in enumerate(example_data_loader):
    print(batch_index, batch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment