Skip to content

Instantly share code, notes, and snippets.

@ab3llini
Last active July 23, 2020 10:58
Show Gist options
  • Save ab3llini/511ea6d7cd1888e6bfc6b658350f8133 to your computer and use it in GitHub Desktop.
Save ab3llini/511ea6d7cd1888e6bfc6b658350f8133 to your computer and use it in GitHub Desktop.
TweetDataset
import torch
from torch.utils.data import Dataset
import json
class TweetDataset(Dataset):
    """Dataset serving pre-processed records loaded from a JSON file.

    The file is parsed once at construction time and kept in memory in
    ``self.data``. Every record must unpack into exactly four array-like
    fields, in the order ``ids, att, tokens, labels`` (three model inputs
    and the label).
    """

    def __init__(self, path, device):
        """Read all records from ``path``.

        Args:
            path: Location of the JSON file containing the pre-processed data.
            device: Device on which the tensors returned by ``__getitem__``
                are created (e.g. ``'cpu'``, ``'cuda'`` or a ``torch.device``).
        """
        self.device = device
        # JSON text is UTF-8 by specification; naming the encoding keeps the
        # decoding from depending on the platform's default locale.
        with open(path, 'r', encoding='utf-8') as fp:
            self.data = json.load(fp)

    def __len__(self):
        """Return the number of records loaded from the JSON file."""
        return len(self.data)

    def __getitem__(self, item):
        """Return record ``item`` as four tensors placed on ``self.device``.

        Returns:
            A tuple ``(ids, att, tokens, labels)`` with dtypes
            ``(int64, float32, int64, int64)``.

        Raises:
            ValueError: If the record does not hold exactly four fields.
        """
        # NOTE(review): presumably token ids, attention mask, token-type ids
        # and the target label -- confirm against the pre-processing script.
        ids, att, tokens, labels = self.data[item]
        device = self.device
        # Build each tensor directly with its final dtype on the target
        # device rather than create -> move -> cast, which allocated up to
        # three tensors per field.
        return (
            torch.tensor(ids, dtype=torch.long, device=device),
            torch.tensor(att, dtype=torch.float, device=device),
            torch.tensor(tokens, dtype=torch.long, device=device),
            torch.tensor(labels, dtype=torch.long, device=device),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment