NASA RUL project - pytorch data loader
# Imports needed for this snippet to run standalone
import torch
import pytorch_lightning as pl
from multiprocessing import cpu_count
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader


class RULDataModule(pl.LightningDataModule):
    """Wraps the train/val/test arrays in PyTorch DataLoaders."""

    def __init__(self, X_train, y_train, X_val, y_val, X_test, y_test,
                 batch_size):
        super().__init__()
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        self.X_test = X_test
        self.y_test = y_test
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Convert the arrays to float tensors once per run
        self.train_dataset = TensorDataset(torch.Tensor(self.X_train),
                                           torch.Tensor(self.y_train))
        self.val_dataset = TensorDataset(torch.Tensor(self.X_val),
                                         torch.Tensor(self.y_val))
        self.test_dataset = TensorDataset(torch.Tensor(self.X_test),
                                          torch.Tensor(self.y_test))

    def train_dataloader(self):
        # Only the training set is shuffled
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=cpu_count())

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=cpu_count())

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=cpu_count())
n_epochs = 65
batch_size = 64

# Hold out 20% of the rolling-window training data for validation,
# stratified on the labels. X_train_rolling, y_train_rolling,
# X_test_rolling and y_test are produced earlier in the project.
X_train_lstm, X_val_lstm, y_train_lstm, y_val_lstm = train_test_split(
    X_train_rolling, y_train_rolling, test_size=0.2, stratify=y_train_rolling)

data_module = RULDataModule(X_train_lstm, y_train_lstm,
                            X_val_lstm, y_val_lstm,
                            X_test_rolling, y_test, batch_size=batch_size)
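As a minimal sketch of how this data module plugs into training, assuming a LightningModule named RULModel defined elsewhere in the project (not shown in this gist), a Trainer run would look roughly like this:

# Hypothetical usage sketch; RULModel is an assumed LightningModule
# defined elsewhere, not part of this gist.
model = RULModel()
trainer = pl.Trainer(max_epochs=n_epochs)
trainer.fit(model, datamodule=data_module)   # uses the train/val loaders above
trainer.test(model, datamodule=data_module)  # uses the test loader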