Skip to content

Instantly share code, notes, and snippets.

@pablosjv
Created August 27, 2021 11:58
Show Gist options
  • Save pablosjv/ef619b1a38027ca0ce54a6a93f4371af to your computer and use it in GitHub Desktop.
Save pablosjv/ef619b1a38027ca0ce54a6a93f4371af to your computer and use it in GitHub Desktop.
Large Scale Pytorch Inference Pipeline: Spark vs Dask - Code Examples
import ast
from collections import namedtuple

import numpy as np
from torch.utils.data import Dataset
# One dataset example: the padded token ids and the padded attention mask,
# as returned by TokensDataset.__getitem__.
Tokens = namedtuple("Tokens", ["input_ids", "attention_mask"])


class TokensDataset(Dataset):
    """Map-style dataset that decodes and zero-pads pre-tokenized sequences.

    Each row of ``iids`` / ``amask`` holds one sequence, stored either as the
    string repr of a Python list (e.g. ``"[101, 2023, 102]"``) or as an
    already-parsed sequence. ``__getitem__`` parses the row and right-pads it
    with zeros to ``max_length``.

    Args:
        iids: Column of token-id sequences; must provide ``to_numpy()``
            (e.g. a ``pandas.Series``).
        amask: Column of attention-mask sequences, row-aligned with ``iids``
            (NOTE(review): alignment is assumed, not checked).
        max_length: Length every returned array is padded to. Defaults to 512.
    """

    def __init__(self, iids, amask, max_length=512):
        self.input_ids = iids.to_numpy()
        self.attention_mask = amask.to_numpy()
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, index):
        """Return row ``index`` as ``Tokens`` of 1-D arrays of length ``max_length``.

        Raises:
            ValueError: if a stored string is not a Python literal, or a
                sequence is longer than ``max_length``.
        """
        input_ids = self._pad(self._parse(self.input_ids[index]))
        amask = self._pad(self._parse(self.attention_mask[index]))
        return Tokens(input_ids, amask)

    @staticmethod
    def _parse(value):
        """Turn one stored cell into a sequence (parsing it if it is text)."""
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        if isinstance(value, str):
            # literal_eval accepts only Python literals, so a corrupted or
            # malicious cell cannot execute code the way eval() could.
            return ast.literal_eval(value)
        # Already-parsed values (e.g. list-typed columns) are used as-is.
        return value

    def _pad(self, seq):
        """Right-pad ``seq`` with zeros to ``self.max_length`` (dtype preserved)."""
        seq = np.asarray(seq)
        n_pad = self.max_length - len(seq)
        if n_pad < 0:
            # np.pad would otherwise fail with an opaque negative-width error.
            raise ValueError(
                f"sequence of length {len(seq)} exceeds max_length={self.max_length}"
            )
        return np.pad(seq, (0, n_pad), mode="constant", constant_values=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment