Skip to content

Instantly share code, notes, and snippets.

@mrm8488
Created August 10, 2020 10:21
Show Gist options
  • Save mrm8488/802b99d9d2ab2115e569100e187b0fd2 to your computer and use it in GitHub Desktop.
Save mrm8488/802b99d9d2ab2115e569100e187b0fd2 to your computer and use it in GitHub Desktop.
class NlpRawTextDataset(Dataset):
    """Map-style dataset that tokenizes one row of a raw text file per item.

    The file is loaded once, in the constructor, through
    ``load_dataset("text", data_files=file_path)["train"]``.  Tokenization is
    done lazily: each ``__getitem__`` call tokenizes a single row, truncated
    to ``block_size`` tokens, and returns its ``input_ids`` as a 1-D tensor.
    Items are NOT padded, so they can differ in length; batching them needs a
    padding collate function.

    Attributes:
        tokenizer: Callable invoked as ``tokenizer(text, add_special_tokens=True,
            truncation=True, max_length=block_size)``; its result must support
            ``["input_ids"]``.
        file_path: Path handed to ``load_dataset`` as ``data_files``.
        block_size: Maximum number of tokens kept per example.
        dataset: The loaded ``"train"`` split.
        len: Number of rows in ``dataset``, cached at construction time.
    """

    def __init__(self, tokenizer, file_path: str, block_size: int) -> None:
        """Load ``file_path`` with the ``"text"`` loader and cache its length.

        Args:
            tokenizer: Tokenizer callable (see class docstring for the call
                signature this class relies on).
            file_path: Location of the raw text file(s) to load.
            block_size: Truncation length, in tokens, applied to every example.
        """
        self.tokenizer = tokenizer
        self.file_path = file_path
        self.block_size = block_size
        print("Loading Dataset...")
        self.dataset = load_dataset("text", data_files=file_path)["train"]
        print("Loaded Dataset!")
        self.len = len(self.dataset)

    def __len__(self) -> int:
        """Return the number of rows loaded from the file."""
        return self.len

    def preprocess(self, text) -> torch.Tensor:
        """Tokenize ``text`` and return its token ids as a 1-D long tensor.

        Args:
            text: The raw text to encode; it is coerced with ``str()`` before
                being handed to the tokenizer.

        Returns:
            Tensor of ``input_ids`` with at most ``block_size`` entries.
        """
        batch_encoding = self.tokenizer(
            str(text),
            add_special_tokens=True,
            truncation=True,
            max_length=self.block_size,
        )
        # Pin the dtype: an empty id list would otherwise become a float
        # tensor, while non-empty int lists already map to torch.long.
        return torch.tensor(batch_encoding["input_ids"], dtype=torch.long)

    def __getitem__(self, i):
        """Return the tokenized ``i``-th row as a tensor of token ids."""
        row = self.dataset[i]
        # Rows produced by the "text" loader are mappings of the form
        # {"text": "<line>"}.  Passing the whole row on would make
        # ``str(row)`` tokenize the dict repr ("{'text': ...}") instead of
        # the line itself, so pull the actual string out first.  Non-dict
        # rows (e.g. plain strings) are passed through unchanged.
        phrase = row["text"] if isinstance(row, dict) else row
        return self.preprocess(phrase)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment