Skip to content

Instantly share code, notes, and snippets.

@lextoumbourou
Created August 20, 2018 09:18
Show Gist options
Save lextoumbourou/8f90313cbc3598ffbabeeaa1741a11c8 to your computer and use it in GitHub Desktop.
Torchtext dataset from DataFrame
from torchtext import data
class DataFrameDataset(data.Dataset):
    """A torchtext ``Dataset`` built from the rows of a DataFrame.

    Each row becomes one ``data.Example`` with two attributes, ``text`` and
    ``label``, handled by ``text_field`` and ``label_field`` respectively.
    """

    def __init__(self, df, text_field, label_field, is_test=False, *,
                 text_column='text', label_column='sentiment', **kwargs):
        """Build one example per row of ``df``.

        Args:
            df: DataFrame with a ``text_column`` column and, unless
                ``is_test`` is true, a ``label_column`` column. It is only
                read, never modified.
            text_field: field registered under the name ``'text'``.
            label_field: field registered under the name ``'label'``.
            is_test: when true, the label column is not read and every
                example is built with ``label=None``.
            text_column: name of the column holding the input text
                (defaults to ``'text'``, the previously hard-coded name).
            label_column: name of the column holding the label
                (defaults to ``'sentiment'``, the previously hard-coded name).
            **kwargs: forwarded unchanged to ``data.Dataset.__init__``.
        """
        fields = [('text', text_field), ('label', label_field)]
        texts = df[text_column]
        # Test frames may have no label column at all, so never touch it
        # when is_test is set.
        labels = [None] * len(df) if is_test else df[label_column]
        # Column-wise zip instead of DataFrame.iterrows: iterrows builds a
        # Series object per row (slow) and yielded an index nobody used.
        examples = [
            data.Example.fromlist([text, label], fields)
            for text, label in zip(texts, labels)
        ]
        # NOTE(review): in test mode every label is None; whether
        # label_field can process None labels when batching depends on its
        # configuration -- confirm before iterating a test dataset.
        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex):
        """Return the length of the example's ``text`` as its sort key."""
        return len(ex.text)

    @classmethod
    def splits(cls, text_field, label_field, train_df, val_df=None,
               test_df=None, **kwargs):
        """Build train / validation / test datasets from DataFrames.

        Any DataFrame passed as ``None`` is skipped. The result is a tuple of
        only the datasets that were built, in train, val, test order. Its
        length therefore depends on how many DataFrames were supplied, so
        callers unpacking a fixed number of values must pass all of them.
        The test dataset is built with ``is_test=True``. ``**kwargs`` are
        forwarded to every constructor call (including ``text_column`` /
        ``label_column``).
        """
        # The constructor only reads from each DataFrame, so the defensive
        # .copy() the original made of every frame is unnecessary.
        datasets = []
        if train_df is not None:
            datasets.append(cls(train_df, text_field, label_field, **kwargs))
        if val_df is not None:
            datasets.append(cls(val_df, text_field, label_field, **kwargs))
        if test_df is not None:
            datasets.append(
                cls(test_df, text_field, label_field, True, **kwargs))
        return tuple(datasets)
# Build all three splits at once. splits() omits any dataset whose DataFrame
# is None, so this three-way unpacking requires train_df, val_df and test_df
# to all be non-None.
train_ds, val_ds, test_ds = DataFrameDataset.splits(
    text_field=TEXT_FIELD,
    label_field=LABEL_FIELD,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment