Example of k-fold cross validation with a PyTorch Lightning DataModule
from typing import Optional

from pytorch_lightning import LightningDataModule
from sklearn.model_selection import KFold
from torch_geometric.data import Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in PyG < 2.0


class ProteinsKFoldDataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data/",
        k: int = 1,  # fold number (1-based)
        split_seed: int = 12345,  # split needs to be always the same for correct cross validation
        num_splits: int = 10,
        batch_size: int = 32,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows accessing init params with the 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # num_splits = 10 means the dataset will be split into 10 parts,
        # so we train on 90% of the data and validate on 10%
        assert 1 <= self.hparams.k <= self.hparams.num_splits, "incorrect fold number"

        # data transformations
        self.transforms = None

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    @property
    def num_node_features(self) -> int:
        return 4

    @property
    def num_classes(self) -> int:
        return 2

    def setup(self, stage=None):
        if not self.data_train and not self.data_val:
            dataset_full = TUDataset(
                self.hparams.data_dir, name="PROTEINS", use_node_attr=True, transform=self.transforms
            )

            # choose the fold to train on
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            all_splits = list(kf.split(dataset_full))
            train_indexes, val_indexes = all_splits[self.hparams.k - 1]  # k is 1-based, splits are 0-based
            train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()

            self.data_train, self.data_val = dataset_full[train_indexes], dataset_full[val_indexes]

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
        )
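
A minimal sketch of how this datamodule could be driven across all folds: each fold gets a fresh datamodule and trainer, and the per-fold validation scores are averaged at the end. The model class LitProteinsClassifier and the logged metric name "val_acc" are hypothetical placeholders assumed for illustration; they are not defined in this gist.

from statistics import mean

from pytorch_lightning import Trainer

num_splits = 10
scores = []

for k in range(1, num_splits + 1):  # fold numbers are 1-based
    datamodule = ProteinsKFoldDataModule(k=k, num_splits=num_splits)
    model = LitProteinsClassifier()  # hypothetical LightningModule that logs "val_acc"
    trainer = Trainer(max_epochs=10)
    trainer.fit(model, datamodule=datamodule)
    # callback_metrics holds the most recently logged metrics as tensors
    scores.append(trainer.callback_metrics["val_acc"].item())

print(f"mean val_acc over {num_splits} folds: {mean(scores):.4f}")

Because split_seed is fixed in the datamodule, every fold is drawn from the same KFold partition, which is what makes the per-fold scores comparable and their average a valid cross-validation estimate.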