@andreaschandra · Created June 6, 2020
import pandas as pd


class DisasterDataset():
    """Disaster-tweet dataset with switchable train/test/eval splits."""

    def __init__(self, data_path, eval_path, tokenizer):
        d_data = pd.read_csv(data_path)
        d_eval = pd.read_csv(eval_path)

        # 80/20 train/test split of the labelled data, by row order
        row, col = d_data.shape
        d_train = d_data.iloc[:int(row * 0.8)].reset_index(drop=True)
        d_test = d_data.iloc[int(row * 0.8):].reset_index(drop=True)

        self.tokenizer = tokenizer
        self.dataset = {'train': (d_train, len(d_train)),
                        'test': (d_test, len(d_test)),
                        'eval': (d_eval, len(d_eval))}
        self.num_labels = len(d_train.target.unique().tolist())
        self.set_split('train')

    def get_vocab(self):
        # Write every lower-cased, whitespace-split token of the active split
        # to vocab.txt, one token per line (duplicates are kept)
        text = " ".join(self.data.text.tolist()).lower()
        vocab = text.split(" ")
        with open('vocab.txt', 'w') as file:
            for word in vocab:
                file.write(word)
                file.write('\n')
        return 'vocab.txt'

    def set_split(self, split='train'):
        # Select the active split: 'train', 'test', or 'eval'
        self.split = split
        self.data, self.length = self.dataset[split]

    def __getitem__(self, idx):
        # Tokenize the text; [0] drops the batch dimension so x is a 1-D tensor of ids
        x = self.data.loc[idx, "text"].lower()
        x = self.tokenizer.encode(x, return_tensors="pt")[0]
        if self.split != 'eval':
            y = self.data.loc[idx, "target"]
            return {'id': idx, 'x': x, 'y': y}
        else:
            # The eval split has no labels; return the row id instead
            id_ = self.data.loc[idx, "id"]
            return {'id': id_, 'x': x}

    def __len__(self):
        return self.length
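
A minimal usage sketch, not part of the original gist: it assumes a Hugging Face tokenizer (here bert-base-uncased) and the Kaggle disaster-tweets CSVs, whose file names below are illustrative.

# Usage sketch (assumed setup: transformers installed, train.csv/test.csv present)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = DisasterDataset("train.csv", "test.csv", tokenizer)

dataset.set_split('train')
sample = dataset[0]               # {'id': 0, 'x': 1-D tensor of token ids, 'y': 0 or 1}
print(len(dataset), dataset.num_labels, sample['x'].shape)

dataset.set_split('eval')         # unlabelled split: items carry the row id, no 'y'
print(dataset[0].keys())          # dict_keys(['id', 'x'])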