Example of k-fold cross-validation with a PyTorch Lightning DataModule. Each fold number selects one of num_splits train/validation splits of the PROTEINS graph dataset; the split seed is fixed so every fold sees the same partitioning, and the DataModule is instantiated once per fold.
from typing import Optional

from pytorch_lightning import LightningDataModule
from sklearn.model_selection import KFold
from torch_geometric.data import Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in PyG < 2.0


class ProteinsKFoldDataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data/",
        k: int = 1,  # fold number (1-based)
        split_seed: int = 12345,  # split needs to be always the same for correct cross validation
        num_splits: int = 10,
        batch_size: int = 32,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows access to the init params through the 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # num_splits = 10 means the dataset is split into 10 parts,
        # so we train on 90% of the data and validate on 10%
        assert 1 <= self.hparams.k <= self.hparams.num_splits, "incorrect fold number"

        # data transformations
        self.transforms = None

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    @property
    def num_node_features(self) -> int:
        return 4

    @property
    def num_classes(self) -> int:
        return 2

    def setup(self, stage=None):
        if self.data_train is None and self.data_val is None:
            dataset_full = TUDataset(
                self.hparams.data_dir, name="PROTEINS", use_node_attr=True, transform=self.transforms
            )

            # choose the fold to train on (k is 1-based, so index with k - 1)
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            all_splits = list(kf.split(dataset_full))
            train_indexes, val_indexes = all_splits[self.hparams.k - 1]
            train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()

            self.data_train, self.data_val = dataset_full[train_indexes], dataset_full[val_indexes]

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
        )
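To run the full cross-validation, instantiate the DataModule once per fold with a fresh Trainer, then average the per-fold validation scores. Below is a minimal driver sketch; ProteinsClassifier is a hypothetical LightningModule and "val_acc" a hypothetical logged metric, so substitute your own model and metric name.

import pytorch_lightning as pl

num_splits = 10
scores = []
for k in range(1, num_splits + 1):  # fold numbers are 1-based here
    datamodule = ProteinsKFoldDataModule(k=k, num_splits=num_splits)
    # ProteinsClassifier is a placeholder for your own LightningModule
    model = ProteinsClassifier(
        num_node_features=datamodule.num_node_features,
        num_classes=datamodule.num_classes,
    )
    trainer = pl.Trainer(max_epochs=20, deterministic=True)
    trainer.fit(model, datamodule=datamodule)
    # read back the last value the model logged under "val_acc"
    scores.append(trainer.callback_metrics["val_acc"].item())

print(f"mean val_acc over {num_splits} folds: {sum(scores) / len(scores):.4f}")

Using a new Trainer and model per fold keeps the folds independent; the fixed split_seed inside the DataModule guarantees each fold sees a disjoint validation slice of the same partitioning.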