Created
September 6, 2020 12:30
from torch.utils.data import Dataset


class CustomTextDataset(Dataset):
    '''
    Simple Dataset that is initialized with X and y vectors.
    The (X, y) pairs are sorted by sequence length at construction time.
    '''
    def __init__(self, X, y=None):
        # y is optional: fall back to None placeholder labels when it is
        # omitted, because zip(X, None) would raise a TypeError
        if y is None:
            y = [None] * len(X)
        self.data = list(zip(X, y))
        # Sort by length of the first element in each tuple (the sequence)
        self.data = sorted(self.data, key=lambda x: len(x[0]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
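The gist does not say what the length sort is for. A common reason is that neighbouring items end up with similar lengths, so each batch needs little padding. The sketch below illustrates that idea in plain Python, without PyTorch. `padded_batches`, `batch_size` and `pad_value` are hypothetical names chosen for this illustration and are not part of the original gist.

```python
# Hypothetical helper, not part of the original gist: pad each batch
# only up to the longest sequence that batch contains.
def padded_batches(pairs, batch_size, pad_value=0):
    # Same sort as CustomTextDataset.__init__: shortest sequences first.
    pairs = sorted(pairs, key=lambda p: len(p[0]))
    batches = []
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        # Longest sequence in this chunk sets the padded width.
        max_len = max(len(seq) for seq, _ in chunk)
        xs = [list(seq) + [pad_value] * (max_len - len(seq)) for seq, _ in chunk]
        ys = [label for _, label in chunk]
        batches.append((xs, ys))
    return batches


# Toy data (made up for this example): token-id sequences and labels.
X = [[1, 2, 3, 4, 5], [6], [7, 8], [9, 10, 11, 12]]
y = [0, 1, 1, 0]
batches = padded_batches(list(zip(X, y)), batch_size=2)
for xs, ys in batches:
    print(xs, ys)
```

With the sort, the two short sequences share a batch padded to length 2 instead of each being padded to length 5. In a PyTorch pipeline the same per-batch padding would typically live in a `collate_fn` passed to a `DataLoader` with `shuffle=False`, so the sorted order is preserved.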