Skip to content

Instantly share code, notes, and snippets.

@brunosan
Created October 11, 2024 11:30
Show Gist options
  • Save brunosan/7cce47b6e1ee2a75b5b78cab6f53488f to your computer and use it in GitHub Desktop.
Save brunosan/7cce47b6e1ee2a75b5b78cab6f53488f to your computer and use it in GitHub Desktop.
# Concatenate every Python file under the current directory into one text dump,
# each preceded by a "# <path>" header line.
# - File names are passed to sh as positional arguments instead of being
#   spliced into the script text, so odd names cannot inject shell code.
# - Every expansion is quoted and contents go through `cat` directly, so the
#   dump is verbatim: no whitespace collapsing and no glob expansion of `*`.
# - `>` (not `>>`) so re-running regenerates the dump instead of appending.
find . -name '*.py' -type f -exec sh -c 'for f in "$@"; do printf "# %s\n" "$f"; cat "$f"; printf "\n"; done' sh {} + > model-all.py.txt
# ./finetune/classify/classify.py
"""
Command line interface to run the neural network model!

From the project root directory, do:

    python classify.py fit --config configs/classify_eurosat.yaml

References:
- https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html
- https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6
"""

from lightning.pytorch.cli import LightningCLI

from finetune.classify.eurosat_datamodule import EuroSATDataModule  # noqa: F401
from finetune.classify.eurosat_model import EuroSATClassifier  # noqa: F401


# %%
def cli_main():
    """
    Command-line interface to run the Classifier model with EuroSATDataModule.

    Returns:
        LightningCLI: The CLI object built from EuroSATClassifier and
        EuroSATDataModule.
    """
    cli = LightningCLI(EuroSATClassifier, EuroSATDataModule)
    return cli


# %%
if __name__ == "__main__":
    cli_main()

    print("Done!")
# ./finetune/classify/factory.py
import re

import torch
from torch import nn

from src.model import Encoder


class Classifier(nn.Module):
    """
    Classifier class uses Clay Encoder for feature extraction and a head for
    classification.

    Attributes:
        clay_encoder (Encoder): The encoder for feature extraction.
        head (nn.Sequential): The head for classification.
        device (torch.device): CUDA if available, else CPU; used as the
            ``map_location`` when loading the checkpoint.
    """

    def __init__(self, num_classes=10, ckpt_path=None):
        """
        Initialize the Classifier.

        Args:
            num_classes (int, optional): The number of classes for
                classification. Defaults to 10.
            ckpt_path (str, optional): Clay MAE pretrained model checkpoint
                path. Defaults to None, in which case no weights are loaded
                and the encoder is left trainable.
        """
        super().__init__()
        # Initialize Clay Encoder with parameters from base model. Set
        # mask_ratio to 0.0 & shuffle to False for downstream tasks.
        self.clay_encoder = Encoder(
            mask_ratio=0.0,
            patch_size=8,
            shuffle=False,
            dim=768,
            depth=12,
            heads=12,
            dim_head=64,
            mlp_ratio=4.0,
        )

        # Simple 2 layer MLP head for classification
        self.head = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, num_classes),
        )

        # Determine the device to run the model on
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

        # Load Clay MAE pretrained weights for the Encoder
        self.load_clay_weights(ckpt_path)

    def load_clay_weights(self, ckpt_path):
        """
        Load the Clay MAE Encoder weights from a checkpoint, then freeze the
        encoder and switch it to evaluation mode.

        Only the ``model.encoder.*`` entries of the checkpoint's
        ``state_dict`` are used. An encoder parameter with no same-named,
        same-sized entry keeps its current value, and a message is printed.

        Args:
            ckpt_path (str | None): Clay MAE pretrained model checkpoint path.
                When None, nothing is loaded and the encoder is left as is.

        Raises:
            KeyError: If the checkpoint has no ``state_dict`` entry.
        """
        if ckpt_path is None:
            return

        # Load the checkpoint file
        ckpt = torch.load(ckpt_path, map_location=self.device)
        state_dict = ckpt["state_dict"]

        # Keep only encoder weights and strip the "model.encoder." prefix so the
        # names line up with self.clay_encoder.named_parameters().
        prefix = re.compile(r"^model\.encoder\.")
        state_dict = {
            prefix.sub("", name): param
            for name, param in state_dict.items()
            if prefix.match(name)
        }

        # Copy the weights from the state dict to the encoder
        for name, param in self.clay_encoder.named_parameters():
            if name in state_dict and param.size() == state_dict[name].size():
                param.data.copy_(state_dict[name])  # Copy the weights
            else:
                print(f"No matching parameter for {name} with size {param.size()}")

        # Freeze clay encoder
        for param in self.clay_encoder.parameters():
            param.requires_grad = False

        # Set the encoder to evaluation mode.
        # NOTE(review): nn.Module.train() recurses into children, so any later
        # .train() call on a parent module switches the encoder back.
        self.clay_encoder.eval()

    def forward(self, datacube):
        """
        Forward pass of the Classifier.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            torch.Tensor: The output logits.
        """
        # Get the embeddings from the encoder
        embeddings, *_ = self.clay_encoder(
            datacube
        )  # embeddings: batch x (1 + row x col) x 768

        # Use only the first embedding i.e cls token
        embeddings = embeddings[:, 0, :]

        # Pass the embeddings through the head to get the logits
        logits = self.head(embeddings)
        return logits
# ./finetune/classify/eurosat_datamodule.py
import lightning as L
import torch
import yaml
from box import Box
from torch.utils.data import DataLoader
from torchgeo.datasets import EuroSAT as TGEuroSAT
from torchvision.transforms import v2

S2_BANDS = [
    "B02",
    "B03",
    "B04",
    "B05",
    "B06",
    "B07",
    "B08",
    "B08A",
    "B11",
    "B12",
]


class EuroSAT(TGEuroSAT):
    """
    Subclass of TGEuroSAT to customize the dataset loading and transformations.

    Args:
        root (str): Root directory of the dataset.
        split (str): Dataset split to use ('train' or 'val').
        bands (list): List of spectral bands to use.
        transforms (callable): Transformations to apply to the samples.
        download (bool): If true, downloads the dataset.
    """

    def __init__(self, root, split, bands, transforms, download):
        super().__init__(root, split, bands, transforms, download)

    def __getitem__(self, index):
        """
        Override the __getitem__ method to apply custom transformations.

        Args:
            index (int): Index of the sample to retrieve.

        Returns:
            dict: A dictionary containing the image tensor, label, and
            additional metadata.
        """
        image, label = self._load_image(index)

        # Keep only the requested bands, as float for normalization.
        image = torch.index_select(image, dim=0, index=self.band_indices).float()
        sample = {
            "pixels": image,
            "label": label,
            "time": torch.zeros(4),  # Placeholder for time information
            "latlon": torch.zeros(4),  # Placeholder for lat/lon information
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample


class EuroSATDataModule(L.LightningDataModule):
    """
    Data module for loading and transforming the EuroSAT dataset.

    Args:
        batch_size (int): Batch size for the dataloaders.
        num_workers (int): Number of workers for data loading.
        metadata_path (str): Path to the metadata file for normalization
            statistics.
    """

    def __init__(self, batch_size, num_workers, metadata_path):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Read the Sentinel-2 normalization statistics; `with` closes the file.
        with open(metadata_path) as metadata_file:
            metadata = Box(yaml.safe_load(metadata_file))["sentinel-2-l2a"]
        mean = list(metadata.bands.mean.values())
        std = list(metadata.bands.std.values())

        self.trn_tfm = v2.Compose(
            [
                v2.RandomHorizontalFlip(),
                v2.RandomVerticalFlip(),
                v2.Normalize(mean, std),
            ]
        )
        self.val_tfm = v2.Compose([v2.Normalize(mean, std)])

    def setup(self, stage=None):
        """
        Setup the datasets for training and validation.

        Args:
            stage (str): Stage of the training process ('fit', 'validate',
                etc.). 'fit' and None build both datasets; 'validate' builds
                only the validation dataset.
        """
        if stage in {"fit", None}:
            self.trn_ds = EuroSAT(
                root="data",
                split="train",
                bands=S2_BANDS,
                transforms=self.trn_tfm,
                download=True,
            )
        if stage in {"fit", "validate", None}:
            self.val_ds = EuroSAT(
                root="data",
                split="val",
                bands=S2_BANDS,
                transforms=self.val_tfm,
                download=True,
            )

    def train_dataloader(self):
        """
        Returns the DataLoader for the training dataset.

        Returns:
            DataLoader: DataLoader for the training dataset.
        """
        return DataLoader(
            self.trn_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            prefetch_factor=2,
        )

    def val_dataloader(self):
        """
        Returns the DataLoader for the validation dataset.

        Validation uses twice the training batch size.

        Returns:
            DataLoader: DataLoader for the validation dataset.
        """
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size * 2,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            prefetch_factor=2,
        )
# ./finetune/classify/eurosat_model.py
import lightning as L
import torch
from torch import nn, optim
from torchmetrics import Accuracy

from finetune.classify.factory import Classifier


class EuroSATClassifier(L.LightningModule):
    """
    LightningModule for training and evaluating a classifier on the EuroSAT
    dataset.

    Args:
        num_classes (int): Number of classes for classification.
        ckpt_path (str): Clay MAE pretrained checkpoint path.
        lr (float): Learning rate for the optimizer.
        wd (float): Weight decay for the optimizer.
        b1 (float): Beta1 parameter for the Adam optimizer.
        b2 (float): Beta2 parameter for the Adam optimizer.
    """

    def __init__(self, num_classes, ckpt_path, lr, wd, b1, b2):  # noqa: PLR0913
        super().__init__()
        self.save_hyperparameters()
        self.model = Classifier(num_classes=num_classes, ckpt_path=ckpt_path)
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, datacube):
        """
        Forward pass through the classifier.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            torch.Tensor: The output logits from the classifier.
        """
        # Wavelengths for Sentinel 2 bands of EuroSAT dataset
        waves = torch.tensor(
            [0.493, 0.56, 0.665, 0.704, 0.74, 0.783, 0.842, 0.865, 1.61, 2.19]
        )
        gsd = torch.tensor(10.0)

        return self.model(
            {
                "pixels": datacube["pixels"],
                "time": datacube["time"],
                "latlon": datacube["latlon"],
                "gsd": gsd,
                "waves": waves,
            }
        )

    def configure_optimizers(self):
        """
        Configure the optimizer and learning rate scheduler.

        Only parameters with ``requires_grad`` are optimized. The scheduler
        is stepped every optimizer step.

        Returns:
            dict: A dictionary containing the optimizer and learning rate
            scheduler.
        """
        optimizer = optim.AdamW(
            [param for param in self.model.parameters() if param.requires_grad],
            lr=self.hparams.lr,
            weight_decay=self.hparams.wd,
            betas=(self.hparams.b1, self.hparams.b2),
        )
        # NOTE(review): eta_min is 100x the base lr, so within each 100-step
        # cycle the lr moves from lr *up* toward lr * 100 instead of decaying.
        # Kept as-is to preserve training behavior; confirm whether
        # `lr / 100` was intended.
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=100, T_mult=1, eta_min=self.hparams.lr * 100, last_epoch=-1
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

    def shared_step(self, batch, batch_idx, phase):
        """
        Perform a shared step for both training and validation.

        Args:
            batch (dict): A batch of data.
            batch_idx (int): The index of the batch.
            phase (str): The phase ('train' or 'val').

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        labels = batch["label"].long()
        logits = self(batch)
        loss = self.loss_fn(logits, labels)
        score = self.accuracy(logits, labels)

        self.log(
            f"{phase}/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            f"{phase}/score",
            score,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        return loss

    def training_step(self, batch, batch_idx):
        """
        Perform a training step.

        Args:
            batch (dict): A batch of training data.
            batch_idx (int): The index of the batch.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        return self.shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """
        Perform a validation step.

        Args:
            batch (dict): A batch of validation data.
            batch_idx (int): The index of the batch.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        return self.shared_step(batch, batch_idx, "val")
# ./finetune/segment/preprocess_data.py
r"""
Chesapeake CVPR Data Processing Script
======================================

This script processes GeoTIFF files from the Chesapeake CVPR dataset to create
image chips for segmentation tasks.

Dataset Source:
---------------
Chesapeake CVPR data from LILA:
https://lila.science/datasets/chesapeakelandcover

For this experiment, we will use images from NY.

Notes:
------
1. Only copy *_lc.tif & *_naip-new.tif files that we will use for our
   segmentation downstream task.
   Using s5cmd for this: https://github.com/peak/s5cmd
   - Train:
     s5cmd cp \
         --no-sign-request \
         --include "*_lc.tif" \
         --include "*_naip-new.tif" \
         "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" \
         data/cvpr/files/train/
   - Val:
     s5cmd cp \
         --no-sign-request \
         --include "*_lc.tif" \
         --include "*_naip-new.tif" \
         "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" \
         data/cvpr/files/val/

2. We will create chips of size `224 x 224` to feed them to the model, feel
   free to experiment with other chip sizes as well.

Run the script as follows:
    python preprocess_data.py <data_dir> <output_dir> <chip_size>

Example:
    python preprocess_data.py data/cvpr/files data/cvpr/ny 224
"""  # noqa: E501

import os
import sys
from pathlib import Path

import numpy as np
import rasterio as rio


def read_and_chip(file_path, chip_size, output_dir):
    """
    Reads a GeoTIFF file, creates chips of specified size, and saves them as
    numpy arrays.

    Chips are saved as ``<stem>_chip_<n>.npy``. They are numbered column by
    column (x outer, y inner) so that image and label files of the same tile
    get matching chip numbers. Edge remainders smaller than ``chip_size`` are
    dropped.

    Args:
        file_path (str or Path): Path to the GeoTIFF file.
        chip_size (int): Size of the square chips.
        output_dir (str or Path): Directory to save the chips.
    """
    os.makedirs(output_dir, exist_ok=True)
    with rio.open(file_path) as src:
        data = src.read()
        n_chips_x = src.width // chip_size
        n_chips_y = src.height // chip_size

    stem = Path(file_path).stem
    chip_number = 0
    for i in range(n_chips_x):
        for j in range(n_chips_y):
            x1, y1 = i * chip_size, j * chip_size
            x2, y2 = x1 + chip_size, y1 + chip_size

            chip = data[:, y1:y2, x1:x2]
            chip_path = os.path.join(output_dir, f"{stem}_chip_{chip_number}.npy")
            np.save(chip_path, chip)
            chip_number += 1


def process_files(file_paths, output_dir, chip_size):
    """
    Processes a list of files, creating chips and saving them.

    Args:
        file_paths (list of Path): List of paths to the GeoTIFF files.
        output_dir (str or Path): Directory to save the chips.
        chip_size (int): Size of the square chips.
    """
    for file_path in file_paths:
        print(f"Processing: {file_path}")
        read_and_chip(file_path, chip_size, output_dir)


def main():
    """
    Main function to process files and create chips.

    Expects three command line arguments:
    - data_dir: Directory containing the input GeoTIFF files.
    - output_dir: Directory to save the output chips.
    - chip_size: Size of the square chips (a positive integer).
    """
    usage = "Usage: python preprocess_data.py <data_dir> <output_dir> <chip_size>"
    if len(sys.argv) != 4:  # noqa: PLR2004
        print(usage)
        sys.exit(1)

    data_dir = Path(sys.argv[1])
    output_dir = Path(sys.argv[2])
    chip_size = int(sys.argv[3])
    # A zero chip size would divide by zero in read_and_chip.
    if chip_size <= 0:
        print(usage)
        sys.exit(1)

    # Sorted so the processing order is deterministic across runs.
    train_image_paths = sorted((data_dir / "train").glob("*_naip-new.tif"))
    val_image_paths = sorted((data_dir / "val").glob("*_naip-new.tif"))
    train_label_paths = sorted((data_dir / "train").glob("*_lc.tif"))
    val_label_paths = sorted((data_dir / "val").glob("*_lc.tif"))

    process_files(train_image_paths, output_dir / "train/chips", chip_size)
    process_files(val_image_paths, output_dir / "val/chips", chip_size)
    process_files(train_label_paths, output_dir / "train/labels", chip_size)
    process_files(val_label_paths, output_dir / "val/labels", chip_size)


if __name__ == "__main__":
    main()
# ./finetune/segment/chesapeake_model.py
"""
LightningModule for training and validating a segmentation model using the
Segmentor class.
"""

import lightning as L
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from torch import optim
from torchmetrics.classification import F1Score, MulticlassJaccardIndex

from finetune.segment.factory import Segmentor


class ChesapeakeSegmentor(L.LightningModule):
    """
    LightningModule for segmentation tasks, utilizing Clay Segmentor.

    Attributes:
        model (nn.Module): Clay Segmentor model.
        loss_fn (nn.Module): The loss function.
        iou (Metric): Intersection over Union metric.
        f1 (Metric): F1 Score metric.
        lr (float): Learning rate.
    """

    def __init__(  # noqa: PLR0913
        self,
        num_classes,
        feature_maps,
        ckpt_path,
        lr,
        wd,
        b1,
        b2,
    ):
        super().__init__()
        self.save_hyperparameters()  # Save hyperparameters for checkpointing
        self.model = Segmentor(
            num_classes=num_classes,
            feature_maps=feature_maps,
            ckpt_path=ckpt_path,
        )

        self.loss_fn = smp.losses.FocalLoss(mode="multiclass")
        self.iou = MulticlassJaccardIndex(
            num_classes=num_classes,
            average="weighted",
        )
        self.f1 = F1Score(
            task="multiclass",
            num_classes=num_classes,
            average="weighted",
        )

    def forward(self, datacube):
        """
        Forward pass through the segmentation model.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            torch.Tensor: The segmentation logits.
        """
        waves = torch.tensor([0.65, 0.56, 0.48, 0.842])  # NAIP wavelengths
        gsd = torch.tensor(1.0)  # NAIP GSD

        # Forward pass through the network
        return self.model(
            {
                "pixels": datacube["pixels"],
                "time": datacube["time"],
                "latlon": datacube["latlon"],
                "gsd": gsd,
                "waves": waves,
            },
        )

    def configure_optimizers(self):
        """
        Configure the optimizer and learning rate scheduler.

        Only parameters with ``requires_grad`` are optimized. The scheduler
        is stepped every optimizer step.

        Returns:
            dict: A dictionary containing the optimizer and scheduler
            configuration.
        """
        optimizer = optim.AdamW(
            [param for param in self.model.parameters() if param.requires_grad],
            lr=self.hparams.lr,
            weight_decay=self.hparams.wd,
            betas=(self.hparams.b1, self.hparams.b2),
        )
        # NOTE(review): eta_min is 100x the base lr, so within each 1000-step
        # cycle the lr moves from lr *up* toward lr * 100 instead of decaying.
        # Kept as-is to preserve training behavior; confirm whether
        # `lr / 100` was intended.
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=1000,
            T_mult=1,
            eta_min=self.hparams.lr * 100,
            last_epoch=-1,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

    def shared_step(self, batch, batch_idx, phase):
        """
        Shared step for training and validation.

        Args:
            batch (dict): A dictionary containing the batch data.
            batch_idx (int): The index of the batch.
            phase (str): The phase (train or val).

        Returns:
            torch.Tensor: The loss value.
        """
        labels = batch["label"].long()
        outputs = self(batch)
        # Resize logits to the labels' spatial size (follows the chip size
        # used at preprocessing time instead of assuming 224 x 224).
        outputs = F.interpolate(
            outputs,
            size=tuple(labels.shape[-2:]),
            mode="bilinear",
            align_corners=False,
        )

        loss = self.loss_fn(outputs, labels)
        iou = self.iou(outputs, labels)
        f1 = self.f1(outputs, labels)

        # Log metrics
        self.log(
            f"{phase}/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            f"{phase}/iou",
            iou,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            f"{phase}/f1",
            f1,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """
        Training step for the model.

        Args:
            batch (dict): A dictionary containing the batch data.
            batch_idx (int): The index of the batch.

        Returns:
            torch.Tensor: The loss value.
        """
        return self.shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """
        Validation step for the model.

        Args:
            batch (dict): A dictionary containing the batch data.
            batch_idx (int): The index of the batch.

        Returns:
            torch.Tensor: The loss value.
        """
        return self.shared_step(batch, batch_idx, "val")
# ./finetune/segment/chesapeake_datamodule.py
"""
DataModule for the Chesapeake Bay dataset for segmentation tasks.

This implementation provides a structured way to handle the data loading and
preprocessing required for training and validating a segmentation model.

Dataset citation:
Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N.
Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data.
Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition
(CVPR 2019).

Dataset URL: https://lila.science/datasets/chesapeakelandcover
"""

import re
from pathlib import Path

import lightning as L
import numpy as np
import torch
import yaml
from box import Box
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2

# Raw label values kept for training, remapped to contiguous class ids 0..6.
LABEL_MAPPING = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 15: 6}


def _remap_labels(label, mapping):
    """
    Map every value of ``label`` through ``mapping`` in one vectorized pass.

    Args:
        label (np.ndarray): Integer label array of any shape.
        mapping (dict): Raw label value -> class id.

    Returns:
        np.ndarray: int64 array with the same shape as ``label``.

    Raises:
        ValueError: If ``label`` contains a value that is not a key of
            ``mapping``.
    """
    keys = np.array(sorted(mapping))
    values = np.array([mapping[key] for key in keys], dtype=np.int64)
    # Position of each pixel value among the sorted keys (clipped so that
    # out-of-range values still index safely; they are rejected below).
    positions = np.clip(np.searchsorted(keys, label), 0, len(keys) - 1)
    known = keys[positions] == label
    if not known.all():
        unknown = np.unique(label[~known]).tolist()
        raise ValueError(f"Label values {unknown} have no class mapping")
    return values[positions]


class ChesapeakeDataset(Dataset):
    """
    Dataset class for the Chesapeake Bay segmentation dataset.

    Args:
        chip_dir (str): Directory containing the image chips.
        label_dir (str): Directory containing the labels.
        metadata (Box): Metadata for normalization and other dataset-specific
            details.
        platform (str): Platform identifier used in metadata.
    """

    def __init__(self, chip_dir, label_dir, metadata, platform):
        self.chip_dir = Path(chip_dir)
        self.label_dir = Path(label_dir)
        self.metadata = metadata
        self.transform = self.create_transforms(
            mean=list(metadata[platform].bands.mean.values()),
            std=list(metadata[platform].bands.std.values()),
        )

        # Load chip and label file names; a label chip shares its image chip's
        # name with "_naip-new_" replaced by "_lc_".
        self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npy")]
        self.labels = [re.sub("_naip-new_", "_lc_", chip) for chip in self.chips]

    def create_transforms(self, mean, std):
        """
        Create normalization transforms.

        Args:
            mean (list): Mean values for normalization.
            std (list): Standard deviation values for normalization.

        Returns:
            torchvision.transforms.Compose: A composition of transforms.
        """
        return v2.Compose(
            [
                v2.Normalize(mean=mean, std=std),
            ],
        )

    def __len__(self):
        return len(self.chips)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: A dictionary containing the image, label, and additional
            information.

        Raises:
            ValueError: If the label chip contains a value missing from
                LABEL_MAPPING.
        """
        chip_name = self.chip_dir / self.chips[idx]
        label_name = self.label_dir / self.labels[idx]

        chip = np.load(chip_name).astype(np.float32)
        label = np.load(label_name)

        # Remap labels to match desired classes
        remapped_label = _remap_labels(label, LABEL_MAPPING)

        sample = {
            "pixels": self.transform(torch.from_numpy(chip)),
            "label": torch.from_numpy(remapped_label[0]),
            "time": torch.zeros(4),  # Placeholder for time information
            "latlon": torch.zeros(4),  # Placeholder for latlon information
        }
        return sample


class ChesapeakeDataModule(L.LightningDataModule):
    """
    DataModule class for the Chesapeake Bay dataset.

    Args:
        train_chip_dir (str): Directory containing training image chips.
        train_label_dir (str): Directory containing training labels.
        val_chip_dir (str): Directory containing validation image chips.
        val_label_dir (str): Directory containing validation labels.
        metadata_path (str): Path to the metadata file.
        batch_size (int): Batch size for data loading.
        num_workers (int): Number of workers for data loading.
        platform (str): Platform identifier used in metadata.
    """

    def __init__(  # noqa: PLR0913
        self,
        train_chip_dir,
        train_label_dir,
        val_chip_dir,
        val_label_dir,
        metadata_path,
        batch_size,
        num_workers,
        platform,
    ):
        super().__init__()
        self.train_chip_dir = train_chip_dir
        self.train_label_dir = train_label_dir
        self.val_chip_dir = val_chip_dir
        self.val_label_dir = val_label_dir
        # `with` closes the metadata file once it is parsed.
        with open(metadata_path) as metadata_file:
            self.metadata = Box(yaml.safe_load(metadata_file))
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.platform = platform

    def setup(self, stage=None):
        """
        Setup datasets for training and validation.

        Args:
            stage (str): Stage identifier ('fit', 'validate' or None). 'fit'
                and None build both datasets; 'validate' builds only the
                validation dataset.
        """
        if stage in {"fit", None}:
            self.trn_ds = ChesapeakeDataset(
                self.train_chip_dir,
                self.train_label_dir,
                self.metadata,
                self.platform,
            )
        if stage in {"fit", "validate", None}:
            self.val_ds = ChesapeakeDataset(
                self.val_chip_dir,
                self.val_label_dir,
                self.metadata,
                self.platform,
            )

    def train_dataloader(self):
        """
        Create DataLoader for training data.

        Returns:
            DataLoader: DataLoader for training dataset.
        """
        return DataLoader(
            self.trn_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """
        Create DataLoader for validation data.

        Returns:
            DataLoader: DataLoader for validation dataset.
        """
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
# ./finetune/segment/factory.py
"""
Clay Segmentor for semantic segmentation tasks.

Attribution:
Decoder from Segformer: Simple and Efficient Design for Semantic Segmentation
with Transformers
Paper URL: https://arxiv.org/abs/2105.15203
"""

import re

import torch
from einops import rearrange, repeat
from torch import nn

from src.model import Encoder


class SegmentEncoder(Encoder):
    """
    Encoder class for segmentation tasks, incorporating a feature pyramid
    network (FPN).

    Attributes:
        feature_maps (list): Indices of layers to be used for generating
            feature maps.
        ckpt_path (str): Path to the clay checkpoint file.
    """

    def __init__(  # noqa: PLR0913
        self,
        mask_ratio,
        patch_size,
        shuffle,
        dim,
        depth,
        heads,
        dim_head,
        mlp_ratio,
        feature_maps,
        ckpt_path=None,
    ):
        super().__init__(
            mask_ratio,
            patch_size,
            shuffle,
            dim,
            depth,
            heads,
            dim_head,
            mlp_ratio,
        )
        self.feature_maps = feature_maps

        # Define Feature Pyramid Network (FPN) layers: fpn1 upsamples 4x,
        # fpn2 upsamples 2x, fpn3/fpn5 keep the size and fpn4 halves it.
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
        )

        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
        )

        self.fpn3 = nn.Identity()

        self.fpn4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.fpn5 = nn.Identity()

        # Set device
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        # Load model from checkpoint if provided
        self.load_from_ckpt(ckpt_path)

    def load_from_ckpt(self, ckpt_path):
        """
        Load the model's state from a checkpoint file.

        Only the ``model.encoder.*`` entries of the checkpoint's
        ``state_dict`` are used. Each entry whose name and size match a
        parameter of this model is copied in and that parameter is frozen.
        Everything else (FPN layers, size-mismatched weights) stays
        trainable. Does nothing when ``ckpt_path`` is falsy.

        Args:
            ckpt_path (str): The path to the checkpoint file.

        Raises:
            KeyError: If the checkpoint has no ``state_dict`` entry.
        """
        if not ckpt_path:
            return

        # Load checkpoint
        ckpt = torch.load(ckpt_path, map_location=self.device)
        state_dict = ckpt["state_dict"]

        # Prepare new state dict with the desired subset and naming
        new_state_dict = {
            re.sub(r"^model\.encoder\.", "", name): param
            for name, param in state_dict.items()
            if name.startswith("model.encoder")
        }

        # Copy matching tensors in place; state_dict() tensors share storage
        # with the model's parameters/buffers.
        model_state_dict = self.state_dict()
        loaded = set()
        for name, param in new_state_dict.items():
            if (
                name in model_state_dict
                and param.size() == model_state_dict[name].size()
            ):
                model_state_dict[name].copy_(param)
                loaded.add(name)
            else:
                print(f"No matching parameter for {name} with size {param.size()}")

        # Freeze only the parameters that actually received pretrained weights
        for name, param in self.named_parameters():
            if name in loaded:
                param.requires_grad = False

    def forward(self, datacube):
        """
        Forward pass of the SegmentEncoder.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            list: One FPN-processed feature map per index in ``feature_maps``,
            followed by the feature map of the final (normalized) layer.
        """
        cube, time, latlon, gsd, waves = (
            datacube["pixels"],  # [B C H W]
            datacube["time"],  # [B 4] as built by the datamodules
            datacube["latlon"],  # [B 4] as built by the datamodules
            datacube["gsd"],  # 1
            datacube["waves"],  # [N]
        )

        B, C, H, W = cube.shape

        # Patchify and create embeddings per patch
        patches, _ = self.to_patch_embed(cube, waves)  # [B L D]
        patches = self.add_encodings(patches, time, latlon, gsd)  # [B L D]

        # Add class tokens
        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
        patches = torch.cat((cls_tokens, patches), dim=1)  # [B (1 + L) D]

        # Patch grid size; assumes patch_size=8, as configured by Segmentor.
        grid_h, grid_w = H // 8, W // 8

        features = []
        for idx, (attn, ff) in enumerate(self.transformer.layers):
            patches = attn(patches) + patches
            patches = ff(patches) + patches
            if idx in self.feature_maps:
                # Drop the cls token and fold the patch sequence back into a grid.
                _cube = rearrange(
                    patches[:, 1:, :], "B (H W) D -> B D H W", H=grid_h, W=grid_w
                )
                features.append(_cube)
        patches = self.transformer.norm(patches)
        _cube = rearrange(
            patches[:, 1:, :], "B (H W) D -> B D H W", H=grid_h, W=grid_w
        )
        features.append(_cube)

        # Apply FPN layers (indexing keeps the IndexError on too many features)
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4, self.fpn5]
        return [ops[i](feature) for i, feature in enumerate(features)]


class Segmentor(nn.Module):
    """
    Clay Segmentor class that combines the Encoder with FPN layers for semantic
    segmentation.

    Attributes:
        num_classes (int): Number of output classes for segmentation.
        feature_maps (list): Indices of layers to be used for generating
            feature maps.
        ckpt_path (str): Path to the checkpoint file.
    """

    def __init__(self, num_classes, feature_maps, ckpt_path):
        super().__init__()
        # Default values are for the clay mae base model.
        self.encoder = SegmentEncoder(
            mask_ratio=0.0,
            patch_size=8,
            shuffle=False,
            dim=768,
            depth=12,
            heads=12,
            dim_head=64,
            mlp_ratio=4.0,
            feature_maps=feature_maps,
            ckpt_path=ckpt_path,
        )
        # Scale factors 1, 2, 4, 8, 4 applied to the five FPN outputs before
        # fusion. nn.ModuleList registers them as submodules; Upsample has no
        # parameters or buffers, so the state_dict keys are unchanged.
        self.upsamples = nn.ModuleList(
            [nn.Upsample(scale_factor=2**i) for i in range(4)]
            + [nn.Upsample(scale_factor=4)]
        )
        # Fuses the five concatenated feature maps back to `dim` channels.
        self.fusion = nn.Conv2d(self.encoder.dim * 5, self.encoder.dim, kernel_size=1)
        self.seg_head = nn.Conv2d(self.encoder.dim, num_classes, kernel_size=1)

    def forward(self, datacube):
        """
        Forward pass of the Segmentor.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            torch.Tensor: The segmentation logits.
        """
        features = self.encoder(datacube)
        features = [self.upsamples[i](feature) for i, feature in enumerate(features)]

        fused = torch.cat(features, dim=1)
        fused = self.fusion(fused)

        logits = self.seg_head(fused)
        return logits
# ./finetune/segment/segment.py
"""
Command line interface to run the neural network model!

From the project root directory, do:

    python segment.py fit --config configs/segment_chesapeake.yaml

References:
- https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html
- https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6
"""

from lightning.pytorch.cli import LightningCLI

from finetune.segment.chesapeake_datamodule import ChesapeakeDataModule  # noqa: F401
from finetune.segment.chesapeake_model import ChesapeakeSegmentor  # noqa: F401


# %%
def cli_main():
    """
    Command-line interface to run the Segmentation Model with
    ChesapeakeDataModule.

    Returns:
        LightningCLI: The CLI object built from ChesapeakeSegmentor and
        ChesapeakeDataModule.
    """
    cli = LightningCLI(ChesapeakeSegmentor, ChesapeakeDataModule)
    return cli


# %%
if __name__ == "__main__":
    cli_main()

    print("Done!")
# ./finetune/regression/preprocess_data.py
import random
from itertools import repeat
from multiprocessing import Pool
from pathlib import Path
from typing import List

import click
import numpy as np
from tifffile import imread

EXPECTED_NR_OF_FILES_PER_TILE = 24
MONTHS = [
    "00",
    "01",
    "02",
    "03",
    "04",
    "05",
    "06",
    "07",
    "08",
    "09",
    "10",
    "11",
]
# Bands kept per month: 2 S1 bands + 10 S2 bands.
NR_OF_BANDS_EXPECTED = 12
# Values treated as nodata in the source tiffs (ND2 only for S1).
ND1 = 0
ND2 = -9999
# Tiles whose mean cube has more masked values than this are skipped.
NODATA_THRESHOLD = 1e5


def list_uniqe_ids(src: Path) -> List[str]:
    """
    Return the unique tile ids of the ``.tif`` files in ``src``.

    A tile id is the file-name part before the first "_". The ids are sorted
    so that the seeded subsampling in ``process`` picks the same tiles on
    every run; iteration order of a ``set`` of str depends on hash
    randomization.
    """
    ids = sorted({dat.name.split("_")[0] for dat in src.glob("*.tif")})
    print(f"Found {len(ids)} unique tile ids")
    return ids


def process_data_for_id(
    id: str,  # noqa: A002 (name kept for backward compatibility)
    feature_path: Path,
    cubes_path: Path,
    overwrite: bool,
) -> None:
    """
    Build the mean datacube of one tile and save it as
    ``biomasters_cube_<id>.npz`` (key ``cube``) in ``cubes_path``.

    For every month, the first 2 S1 bands and first 10 S2 bands are stacked.
    Months that do not yield exactly 12 bands are skipped. The remaining
    months are averaged as masked arrays, ignoring nodata values.

    Args:
        id: Tile id (file-name prefix).
        feature_path: Folder with the ``<id>_<S1|S2>_<month>.tif`` files.
        cubes_path: Output folder.
        overwrite: Rebuild the cube even if the output file already exists.
    """
    out_path = cubes_path / f"biomasters_cube_{id}.npz"
    if not overwrite and out_path.exists():
        print(f"Found existing file for {id}, skipping.")
        return

    data = []
    for month in MONTHS:
        data_month = []
        for platform in ["S1", "S2"]:
            feature_file = feature_path / f"{id}_{platform}_{month}.tif"
            if not feature_file.exists():
                continue
            # Move the last axis to the front (presumably bands are stored
            # last in the tiff -- TODO confirm), giving bands-first arrays.
            file_data = imread(feature_file).swapaxes(1, 2).swapaxes(0, 1)
            if platform == "S1":
                # Limit to first orbit (the other is mostly nodata)
                file_data = file_data[:2]
                file_data = np.ma.array(
                    file_data, mask=np.logical_or(file_data == ND1, file_data == ND2)
                )
            else:
                file_data = file_data[:10]
                file_data = np.ma.array(file_data, mask=file_data == ND1)
            data_month.append(file_data)

        # Neither S1 nor S2 exists for this month: vstack([]) would raise.
        if not data_month:
            continue

        data_month = np.ma.vstack(data_month)
        if data_month.shape[0] != NR_OF_BANDS_EXPECTED:
            continue
        data.append(data_month)

    if not data:
        print(f"No complete month found for {id}, skipping.")
        return

    cube = np.ma.array(data)
    mean_cube = np.ma.mean(cube, axis=0)

    nodata_count = np.sum(mean_cube.mask)
    if nodata_count:
        print("Nodata", nodata_count)
        if nodata_count > NODATA_THRESHOLD:
            print("Skipping due to lots of nodata")
            return

    np.savez_compressed(out_path, cube=mean_cube)


@click.command()
@click.option(
    "--features",
    help="Folder with features (training or test)",
    type=click.Path(path_type=Path),
    required=True,
)
@click.option(
    "--cubes",
    help="Folder to write the datacubes",
    type=click.Path(path_type=Path),
    required=True,
)
@click.option(
    "--processes",
    default=1,
    help="How many processes to use for parallel processing",
    type=click.INT,
)
@click.option(
    "--sample",
    default=0.05,
    help="Fraction of original data to sample",
    type=click.FloatRange(0, 1),
)
@click.option(
    "--overwrite",
    is_flag=True,
    help="Overwrite existing cubes",
)
def process(features, cubes, processes, sample, overwrite):
    """
    Combine tiff files into npz datacubes.

    The datacubes will contain the S1 vv/vh bands of the first orbit,
    stacked with the first 10 S2 bands.
    """
    ids = list_uniqe_ids(features)

    if sample < 1:
        sample_length = int(len(ids) * sample)
        # Dedicated RNG seeded with 42: same picks as random.seed(42) followed
        # by random.sample, without touching the global random state.
        ids = random.Random(42).sample(ids, sample_length)
        print(f"Subsampled {len(ids)} tiles")

    # np.savez_compressed does not create missing directories.
    cubes.mkdir(parents=True, exist_ok=True)

    if processes > 1:
        with Pool(processes) as pl:
            pl.starmap(
                process_data_for_id,
                zip(ids, repeat(features), repeat(cubes), repeat(overwrite)),
            )
    else:
        for tile_id in ids:
            process_data_for_id(tile_id, features, cubes, overwrite)


if __name__ == "__main__":
    process()
# ./finetune/regression/regression.py
"""
Command line interface to run the neural network model!

From the project root directory, do:

    python regression.py fit --config configs/regression_biomasters.yaml

References:
- https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html
- https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6
"""

from lightning.pytorch.cli import LightningCLI

from finetune.regression.biomasters_datamodule import BioMastersDataModule  # noqa: F401
from finetune.regression.biomasters_model import BioMastersClassifier  # noqa: F401


# %%
def cli_main():
    """
    Command-line interface to run regression with BioMastersDataModule.

    The saved run config is overwritten if it already exists.
    """
    cli = LightningCLI(
        BioMastersClassifier,
        BioMastersDataModule,
        save_config_kwargs={"overwrite": True},
    )
    return cli


# %%
if __name__ == "__main__":
    cli_main()

    print("Done!")
# ./finetune/regression/factory.py
"""
Clay Segmentor for semantic segmentation tasks.

Attribution:
Decoder from Segformer: Simple and Efficient Design for Semantic Segmentation
with Transformers
Paper URL: https://arxiv.org/abs/2105.15203
"""

import re

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

from src.model import Encoder


class SegmentEncoder(Encoder):
    """
    Encoder class for segmentation tasks, incorporating a feature pyramid
    network (FPN).

    Attributes:
        feature_maps (list): Indices of transformer layers whose outputs are
            used as feature maps.
        ckpt_path (str): Path to the clay checkpoint file.
    """

    def __init__(  # noqa: PLR0913
        self,
        mask_ratio,
        patch_size,
        shuffle,
        dim,
        depth,
        heads,
        dim_head,
        mlp_ratio,
        feature_maps,
        ckpt_path=None,
    ):
        super().__init__(
            mask_ratio,
            patch_size,
            shuffle,
            dim,
            depth,
            heads,
            dim_head,
            mlp_ratio,
        )
        self.feature_maps = feature_maps

        # Feature Pyramid Network (FPN) layers. Relative to the patch grid
        # they rescale by 4x, 2x, 1x, 1/2x and 1/4x respectively.
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
        )
        self.fpn3 = nn.Identity()
        self.fpn4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fpn5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=4, stride=4),
        )

        # Device the checkpoint tensors are mapped to when loading.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

        # Load model from checkpoint if provided
        self.load_from_ckpt(ckpt_path)

    def load_from_ckpt(self, ckpt_path):
        """
        Load the Clay MAE encoder weights from a checkpoint and freeze them.

        Only entries under ``model.encoder.`` are used, with that prefix
        stripped. Entries whose name or shape does not match this module are
        reported and skipped. Only the parameters that were actually loaded
        are frozen; everything else (e.g. the FPN layers) stays trainable.

        Args:
            ckpt_path (str): The path to the checkpoint file. Nothing happens
                when it is falsy.

        Raises:
            KeyError: If the checkpoint has no ``state_dict`` entry.
        """
        if not ckpt_path:
            return

        ckpt = torch.load(ckpt_path, map_location=self.device)
        state_dict = ckpt.get("state_dict")
        if state_dict is None:
            raise KeyError(f"Checkpoint {ckpt_path} has no 'state_dict' entry")

        # Keep only the encoder weights, renamed to this module's names.
        new_state_dict = {
            re.sub(r"^model\.encoder\.", "", name): param
            for name, param in state_dict.items()
            if name.startswith("model.encoder.")
        }

        # state_dict() tensors share storage with the module's parameters,
        # so copying into them updates the model in place.
        model_state_dict = self.state_dict()
        loaded = set()
        for name, param in new_state_dict.items():
            target = model_state_dict.get(name)
            if target is not None and param.size() == target.size():
                target.copy_(param)
                loaded.add(name)
            else:
                print(f"No matching parameter for {name} with size {param.size()}")

        # Freeze only the parameters that received pretrained weights.
        for name, param in self.named_parameters():
            if name in loaded:
                param.requires_grad = False

    def forward(self, datacube):
        """
        Forward pass of the SegmentEncoder.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            list: One FPN-processed feature map per index in ``feature_maps``
                (in layer order).

        Raises:
            ValueError: If more feature maps are requested than FPN levels.
        """
        cube, time, latlon, gsd, waves = (
            datacube["pixels"],  # [B C H W]
            datacube["time"],  # time & latlon together must give 8 features
            datacube["latlon"],
            datacube["gsd"],  # 1
            datacube["waves"],  # [N]
        )

        B, _, H, W = cube.shape
        grid_h, grid_w = H // self.patch_size, W // self.patch_size

        # Patchify and create embeddings per patch
        patches, _ = self.to_patch_embed(cube, waves)  # [B L D]
        patches = self.add_encodings(patches, time, latlon, gsd)  # [B L D]

        # Add class tokens
        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
        patches = torch.cat((cls_tokens, patches), dim=1)  # [B (1 + L) D]

        # Run the transformer layer by layer to tap intermediate outputs.
        features = []
        for idx, (attn, ff) in enumerate(self.transformer.layers):
            patches = attn(patches) + patches
            patches = ff(patches) + patches
            if idx in self.feature_maps:
                # Drop the class token and fold patches back into a grid.
                _cube = rearrange(
                    patches[:, 1:, :], "B (H W) D -> B D H W", H=grid_h, W=grid_w
                )
                features.append(_cube)

        # Apply FPN layers
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4, self.fpn5]
        if len(features) > len(ops):
            raise ValueError(
                f"Got {len(features)} feature maps but only {len(ops)} FPN levels"
            )
        return [op(feature) for op, feature in zip(ops, features)]


class FusionBlock(nn.Module):
    """3x3 convolution + batch norm + ReLU used to fuse pyramid features."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(output_dim)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


class SegmentationHead(nn.Module):
    """Two-convolution head mapping features to ``num_classes`` channels."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, input_dim // 2, kernel_size=3, padding=1)
        # final conv to num_classes
        self.conv2 = nn.Conv2d(input_dim // 2, num_classes, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(input_dim // 2)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        return self.conv2(x)  # No activation after the final layer


class Regressor(nn.Module):
    """
    Clay Regressor class that combines the Encoder with FPN layers for
    dense (per-pixel) regression.

    Attributes:
        num_classes (int): Number of output channels.
        feature_maps (list): Indices of layers to be used for generating
            feature maps.
        ckpt_path (str): Path to the checkpoint file.
    """

    def __init__(self, num_classes, feature_maps, ckpt_path):
        super().__init__()
        # Default values are for the clay mae base model.
        self.encoder = SegmentEncoder(
            mask_ratio=0.0,
            patch_size=8,
            shuffle=False,
            dim=768,
            depth=12,
            heads=12,
            dim_head=64,
            mlp_ratio=4.0,
            feature_maps=feature_maps,
            ckpt_path=ckpt_path,
        )
        # After FPN level i, upsampling by 2**i brings every level to 4x the
        # patch grid, so the levels can be summed. Upsample has no state, so
        # registering it in a ModuleList does not change the state_dict.
        self.upsamples = nn.ModuleList(
            [nn.Upsample(scale_factor=2**i) for i in range(5)]
        )
        self.fusion = FusionBlock(self.encoder.dim, self.encoder.dim // 4)
        self.seg_head = nn.Conv2d(
            self.encoder.dim // 4, num_classes, kernel_size=3, padding=1
        )

    def forward(self, datacube):
        """
        Forward pass of the Regressor.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            torch.Tensor: Per-pixel predictions at 4x the patch-grid
                resolution, with ``num_classes`` channels.
        """
        features = self.encoder(datacube)
        features = [up(feature) for up, feature in zip(self.upsamples, features)]

        fused = torch.sum(torch.stack(features), dim=0)
        fused = self.fusion(fused)

        return self.seg_head(fused)
# ./finetune/regression/biomasters_model.py
import lightning as L
import torch
import torch.nn.functional as F
from torch import nn, optim
from torchmetrics import MeanSquaredError

from finetune.regression.factory import Regressor

# Wavelengths for the 10 S2 bands fed to the model. The S1 bands are dropped
# by the dataset, so no S1 wavelengths are listed.
BIOMASTERS_S2_WAVES = (
    0.493,
    0.56,
    0.665,
    0.704,
    0.74,
    0.783,
    0.842,
    0.865,
    1.61,
    2.19,
)


class NoNaNRMSE(nn.Module):
    """
    Per-sample RMSE that ignores target pixels at or above ``threshold``.

    The squared error is averaged over the valid pixels of each sample, the
    square root is taken per sample, and the result is averaged over the
    batch. Samples without any valid pixel are left out of the batch mean,
    so they cannot turn the loss into NaN.
    """

    def __init__(self, threshold=400):
        super().__init__()
        self.threshold = threshold

    def forward(self, logits, target):
        valid = target < self.threshold
        diff = logits - target
        diff = torch.where(valid, diff, torch.zeros_like(diff))

        # Reduce over the last three dims (channels, height, width).
        n_valid = valid.sum((-1, -2, -3))
        mse = diff.square().sum((-1, -2, -3)) / n_valid.clamp(min=1)

        has_valid = n_valid > 0
        if not has_valid.any():
            # Nothing to score; return a zero that stays attached to the graph.
            return (logits * 0).sum()
        return torch.sqrt(mse[has_valid]).mean()


class BioMastersClassifier(L.LightningModule):
    """
    LightningModule for training and evaluating a regression on the
    BioMasters dataset.

    Args:
        ckpt_path (str): Clay MAE pretrained checkpoint path.
        feature_maps (list): Indices of encoder layers used as feature maps.
        lr (float): Learning rate for the optimizer.
        wd (float): Weight decay for the optimizer.
        b1 (float): Beta1 parameter for the Adam optimizer.
        b2 (float): Beta2 parameter for the Adam optimizer.
    """

    def __init__(self, ckpt_path, feature_maps, lr, wd, b1, b2):  # noqa: PLR0913
        super().__init__()
        self.save_hyperparameters()
        self.model = Regressor(
            num_classes=1, feature_maps=feature_maps, ckpt_path=ckpt_path
        )
        self.loss_fn = NoNaNRMSE()
        self.score_fn = MeanSquaredError()

    def forward(self, datacube):
        """
        Forward pass through the regressor.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            torch.Tensor: The output predictions of the regressor.
        """
        waves = torch.tensor(BIOMASTERS_S2_WAVES)
        gsd = torch.tensor(10.0)

        return self.model(
            {
                "pixels": datacube["pixels"],
                "time": datacube["time"],
                "latlon": datacube["latlon"],
                "gsd": gsd,
                "waves": waves,
            }
        )

    def configure_optimizers(self):
        """
        Configure the optimizer and learning rate scheduler.

        Only trainable parameters (those not frozen by the checkpoint loader)
        are optimized. The learning rate is halved every 8 epochs.

        Returns:
            dict: A dictionary containing the optimizer and learning rate
                scheduler.
        """
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(
            trainable,
            lr=self.hparams.lr,
            weight_decay=self.hparams.wd,
            betas=(self.hparams.b1, self.hparams.b2),
        )
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
            },
        }

    def shared_step(self, batch, batch_idx, phase):
        """
        Perform a shared step for both training and validation.

        Args:
            batch (dict): A batch of data.
            batch_idx (int): The index of the batch.
            phase (str): The phase ('train' or 'val').

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        labels = batch["label"]
        logits = self(batch)
        # Resize predictions to the label resolution.
        logits = F.interpolate(
            logits,
            size=tuple(labels.shape[-2:]),
            mode="bilinear",
            align_corners=False,
        )

        loss = self.loss_fn(logits, labels)
        # NOTE: unlike the loss, the score covers every pixel, including those
        # at or above the loss threshold. Convert MSE to RMSE.
        score = torch.sqrt(self.score_fn(logits, labels))

        self.log(
            f"{phase}/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            f"{phase}/score",
            score,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """
        Perform a training step.

        Args:
            batch (dict): A batch of training data.
            batch_idx (int): The index of the batch.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        return self.shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """
        Perform a validation step.

        Args:
            batch (dict): A batch of validation data.
            batch_idx (int): The index of the batch.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        return self.shared_step(batch, batch_idx, "val")
# ./finetune/regression/biomasters_datamodule.py
"""
DataModule for the BioMasters dataset for a regression task.

BioMassters: A Benchmark Dataset for Forest Biomass Estimation using
Multi-modal Satellite Time-series https://nascetti-a.github.io/BioMasster/

This implementation provides a structured way to handle the data loading and
preprocessing required for training and validating a regression model.

Citation:

Andrea Nascetti, Ritu Yadav, Kirill Brodt, Qixun Qu, Hongwei Fan, Yuri
Shendryk, Isha Shah, and Christine Chung, BioMassters: A Benchmark Dataset
for Forest Biomass Estimation using Multi-modal Satellite Time-series,
Thirty-seventh Conference on Neural Information Processing Systems Datasets
and Benchmarks Track, 2023, https://openreview.net/forum?id=hrWsIC4Cmz
"""

from pathlib import Path

import lightning as L
import numpy as np
import torch
import yaml
from box import Box
from tifffile import imread
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2


class BioMastersDataset(Dataset):
    """
    Dataset class for the BioMasters regression dataset.

    Assumes the band order written by ``preprocess_data.py``: vv, vh, B2, B3,
    B4, B5, B6, B7, B8, B8A, B11, B12. The two S1 bands are dropped, so
    samples contain only the 10 S2 bands.

    Args:
        chip_dir (str): Directory containing the image chips (``*.npz``).
        label_dir (str): Directory containing the labels.
        metadata (Box): Clay metadata providing the S2 normalization stats.
    """

    def __init__(self, chip_dir, label_dir, metadata):
        self.chip_dir = Path(chip_dir)
        self.label_dir = Path(label_dir)
        self.metadata = metadata

        # Load S2 statistics from Clay metadata
        s2_mean = list(metadata["sentinel-2-l2a"].bands.mean.values())
        s2_std = list(metadata["sentinel-2-l2a"].bands.std.values())

        self.transform = self.create_transforms(
            mean=s2_mean,
            std=s2_std,
        )

        # Sorted so the sample order does not depend on the filesystem.
        self.chips = sorted(
            chip_path.name for chip_path in self.chip_dir.glob("*.npz")
        )
        print(f"Found {len(self.chips)} chips to process for {chip_dir}")

    def create_transforms(self, mean, std):
        """
        Create normalization transforms.

        Args:
            mean (list): Mean values for normalization.
            std (list): Standard deviation values for normalization.

        Returns:
            torchvision.transforms.v2.Compose: A composition of transforms.
        """
        return v2.Compose(
            [
                v2.Normalize(mean=mean, std=std),
            ],
        )

    def __len__(self):
        return len(self.chips)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: A dictionary containing the normalized pixels, the label and
                zero placeholders for time and latlon.
        """
        chip_name = self.chip_dir / self.chips[idx]
        # biomasters_cube_{id}.npz -> {id}_agbm.tif
        label_name = self.label_dir / (chip_name.stem.split("_")[-1] + "_agbm.tif")

        # Drop the two S1 bands; the NpzFile is closed when leaving the block.
        with np.load(chip_name) as npz:
            chip = npz["cube"][2:, ...].astype("float32")
        label = imread(label_name).astype("float32")
        label = np.expand_dims(label, 0)

        return {
            "pixels": self.transform(torch.from_numpy(chip)),
            "label": torch.from_numpy(label),
            "time": torch.zeros(4),  # Placeholder for time information
            "latlon": torch.zeros(4),  # Placeholder for latlon information
        }


class BioMastersDataModule(L.LightningDataModule):
    """
    DataModule class for the BioMasters dataset.

    Args:
        train_chip_dir (str): Directory containing training image chips.
        train_label_dir (str): Directory containing training labels.
        val_chip_dir (str): Directory containing validation image chips.
        val_label_dir (str): Directory containing validation labels.
        metadata_path (str): Path to the Clay metadata YAML file.
        batch_size (int): Batch size for data loading.
        num_workers (int): Number of workers for data loading.
    """

    def __init__(  # noqa: PLR0913
        self,
        train_chip_dir,
        train_label_dir,
        val_chip_dir,
        val_label_dir,
        metadata_path,
        batch_size,
        num_workers,
    ):
        super().__init__()
        self.train_chip_dir = train_chip_dir
        self.train_label_dir = train_label_dir
        self.val_chip_dir = val_chip_dir
        self.val_label_dir = val_label_dir
        with open(metadata_path) as f:
            self.metadata = Box(yaml.safe_load(f))
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """
        Setup datasets for training and validation.

        Args:
            stage (str): Stage identifier ('fit', 'validate' or None). The
                training set is built for 'fit'/None, the validation set for
                'fit'/'validate'/None.
        """
        if stage in {"fit", None}:
            self.trn_ds = BioMastersDataset(
                self.train_chip_dir,
                self.train_label_dir,
                self.metadata,
            )
        if stage in {"fit", "validate", None}:
            self.val_ds = BioMastersDataset(
                self.val_chip_dir,
                self.val_label_dir,
                self.metadata,
            )

    def train_dataloader(self):
        """
        Create DataLoader for training data.

        Returns:
            DataLoader: Shuffled DataLoader for the training dataset.
        """
        return DataLoader(
            self.trn_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """
        Create DataLoader for validation data.

        Returns:
            DataLoader: DataLoader for the validation dataset.
        """
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
# ./trainer.py
"""
Command line interface to run the neural network model!

From the project root directory, do:

    python trainer.py fit

References:
- https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html
- https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6
"""

from lightning.pytorch.cli import LightningCLI

# Imported for their side effect of making the classes available to the CLI
# config (they are not referenced directly below).
from src.datamodule import ClayDataModule  # noqa: F401
from src.model import ClayMAEModule  # noqa: F401


# %%
def cli_main():
    """
    Command-line interface to run ClayMAE with ClayDataModule.

    The saved run config is overwritten if it already exists.
    """
    cli = LightningCLI(save_config_kwargs={"overwrite": True})
    return cli


# %%
if __name__ == "__main__":
    cli_main()

    print("Done!")
# ./src/callbacks_wandb.py
"""
Lightning callback functions for logging to Weights & Biases.

Includes a way to visualize RGB images derived from the raw logits of a Masked
Autoencoder's decoder during the validation loop. I.e. to see if the Vision
Transformer model is learning how to do image reconstruction.

Usage:

```
import lightning as L

from src.callbacks_wandb import LogMAEReconstruction

trainer = L.Trainer(
    ...,
    callbacks=[LogMAEReconstruction(num_samples=6)]
)
```

References:
- https://lightning.ai/docs/pytorch/2.1.0/common/trainer.html#callbacks
- https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
- https://github.com/ashleve/lightning-hydra-template/blob/wandb-callbacks/src/callbacks/wandb_callbacks.py#L245
"""

import itertools

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import skimage
import torch
from einops import rearrange

try:
    import wandb
except ImportError:
    wandb = None


# %%
def get_wandb_logger(trainer: L.Trainer) -> L.pytorch.loggers.WandbLogger:
    """
    Safely get Weights & Biases logger from Trainer.

    Raises:
        Exception: In ``fast_dev_run`` mode, or when no WandbLogger is attached.
    """
    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables "
            "loggers in `fast_dev_run=true` mode."
        )

    for logger in trainer.loggers:
        if isinstance(logger, L.pytorch.loggers.WandbLogger):
            return logger

    raise Exception(
        "You are using wandb related callback, "
        "but WandbLogger was not found for some reason..."
    )


class LogMAEReconstruction(L.Callback):
    """
    Logs reconstructed RGB images from a Masked Autoencoder's decoder to WandB.

    NOTE(review): relies on ``pl_module.vit.unpatchify``, ``outputs["logits"]``
    and ``batch["image"]``; confirm these exist on the module in use.
    """

    def __init__(self, num_samples: int = 8):
        """
        Define how many sample images to log.

        Parameters
        ----------
        num_samples : int
            The number of RGB image samples to upload to WandB. Default is 8.
        """
        super().__init__()
        self.num_samples: int = num_samples
        self.ready: bool = False

        if wandb is None:
            raise ModuleNotFoundError(
                "Package `wandb` is required to be installed to use this callback. "
                "Please use `pip install wandb` or "
                "`conda install -c conda-forge wandb` "
                "to install the package"
            )

    def on_sanity_check_start(self, trainer, pl_module):
        """
        Don't execute callback before validation sanity checks are completed.
        """
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """
        Start executing callback only after all validation sanity checks end.
        """
        self.ready = True

    def on_validation_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: dict[str, torch.Tensor],
        batch: dict[str, torch.Tensor | list[str]],
        batch_idx: int,
    ) -> list:
        """
        Called in the validation loop at the end of every mini-batch.

        On the first mini-batch only, take a sample of original and
        reconstructed images, keep the RGB bands, apply histogram
        equalization and log them to WandB. Returns the logged images, or
        None when nothing was logged.
        """
        if not (self.ready and batch_idx == 0):  # only run on first mini-batch
            return None

        with torch.inference_mode():
            # Get WandB logger
            self.logger = get_wandb_logger(trainer=trainer)

            # Turn raw logits into reconstructed images
            patchified_pixel_values: torch.Tensor = outputs["logits"]
            y_hat: torch.Tensor = pl_module.vit.unpatchify(
                patchified_pixel_values=patchified_pixel_values
            )

            # Reshape tensors from channel-first to channel-last
            x: torch.Tensor = torch.einsum(
                "bchw->bhwc", batch["image"][: self.num_samples]
            )
            y_hat: torch.Tensor = torch.einsum("bchw->bhwc", y_hat[: self.num_samples])
            assert x.shape == y_hat.shape

            # Channels 2, 1, 0 are plotted as R, G, B.
            rgb_original: np.ndarray = (
                x[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy()
            )
            rgb_reconstruction: np.ndarray = (
                y_hat[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy()
            )

            figures: list[wandb.Image] = []
            for i in range(min(x.shape[0], self.num_samples)):
                figures.append(
                    wandb.Image(
                        data_or_path=skimage.exposure.equalize_hist(
                            image=rgb_original[i]
                        ),
                        caption=f"RGB Image {i}",
                    )
                )
                figures.append(
                    wandb.Image(
                        data_or_path=skimage.exposure.equalize_hist(
                            image=rgb_reconstruction[i]
                        ),
                        caption=f"Reconstructed {i}",
                    )
                )

            # Upload figures to WandB
            self.logger.experiment.log(data={"Examples": figures})

        return figures


class LogIntermediatePredictions(L.Callback):
    """Visualize the model results at the end of every epoch."""

    def __init__(self):
        """
        Instantiates with wandb-logger.
        """
        super().__init__()

    def on_validation_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
    ) -> None:
        """
        Called when the validation loop ends.

        For up to the first 6 validation batches, reconstruct the pixels with
        the encoder/decoder and log a figure of band 0 (actual vs predicted)
        per batch, keyed by the batch's platform. Up to 16 samples per batch
        are plotted (fewer if the batch is smaller).
        """
        with torch.no_grad():
            # Get WandB logger
            self.logger = get_wandb_logger(trainer=trainer)

            # islice stops cleanly when there are fewer than 6 batches.
            for batch in itertools.islice(trainer.val_dataloaders, 6):
                platform = batch["platform"][0]
                batch = {
                    k: v.to(pl_module.device)
                    for k, v in batch.items()
                    if isinstance(v, torch.Tensor)
                }
                metadata = trainer.datamodule.metadata[platform]
                waves = torch.tensor(list(metadata.bands.wavelength.values()))
                gsd = torch.tensor(metadata.gsd)

                # ENCODER
                (
                    encoded_unmasked_patches,
                    unmasked_indices,
                    masked_indices,
                    masked_matrix,
                ) = pl_module.model.encoder(
                    {
                        "pixels": batch["pixels"],
                        "time": batch["time"],
                        "latlon": batch["latlon"],
                        "gsd": gsd,
                        "waves": waves,
                    }
                )

                # DECODER
                pixels, waves = pl_module.model.decoder(
                    encoded_unmasked_patches,
                    unmasked_indices,
                    masked_indices,
                    masked_matrix,
                    batch["time"],
                    batch["latlon"],
                    gsd,
                    waves,
                )

                # Fold patch predictions back into images.
                patch_size = pl_module.model.patch_size
                grid = trainer.datamodule.size // patch_size
                pixels = rearrange(
                    pixels,
                    "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
                    p1=patch_size,
                    p2=patch_size,
                    h=grid,
                    w=grid,
                )

                assert pixels.shape == batch["pixels"].shape

                # Rows 0/2 show actual images, rows 1/3 predictions; each
                # column pair needs two samples, so cap by the batch size.
                n_rows = 4
                n_cols = min(8, batch["pixels"].shape[0] // 2)
                if n_cols == 0:
                    continue

                fig, axs = plt.subplots(
                    n_rows, n_cols, figsize=(20, 8), squeeze=False
                )
                for j in range(n_cols):
                    k = j + n_cols
                    panels = (
                        (0, batch["pixels"][j], f"Actual {j}"),
                        (2, batch["pixels"][k], f"Actual {k}"),
                        (1, pixels[j], f"Pred {j}"),
                        (3, pixels[k], f"Pred {k}"),
                    )
                    for row, image, title in panels:
                        ax = axs[row, j]
                        ax.imshow(image[0].detach().cpu().numpy(), cmap="viridis")
                        ax.set_title(title)
                        ax.axis("off")

                self.logger.experiment.log({f"{platform}": wandb.Image(fig)})
                plt.close(fig)
# ./src/factory.py
"""Dynamic Embedding from DOFA paper.

Reference:
- https://arxiv.org/abs/2403.15356
- https://github.com/zhu-xlab/DOFA
"""

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from src.utils import posemb_sincos_1d


class FCBlock(nn.Module):
    """Two GELU-activated linear layers of equal width plus a residual add."""

    def __init__(self, size):
        super().__init__()
        self.l1 = nn.Linear(size, size)
        self.l2 = nn.Linear(size, size)

    def forward(self, x):
        y = F.gelu(self.l1(x))
        y = F.gelu(self.l2(y))
        return x + y


class WavesTransformer(nn.Module):
    """
    Generate dynamic weights (and, for the encoder, a bias) from wavelength
    embeddings.

    Learned latent tokens and a bias token are concatenated with the
    wavelength tokens and passed through a transformer encoder. The outputs
    at the wavelength positions produce one weight row per band; the output
    at the bias token produces the bias (encoder only).
    """

    def __init__(  # noqa: PLR0913
        self,
        wave_dim,
        output_dim,
        num_latent_tokens,
        embed_dim,
        is_decoder,
        num_heads=4,
        num_layers=1,
    ):
        super().__init__()
        self.num_latent_tokens = num_latent_tokens
        self.is_decoder = is_decoder
        layer = nn.TransformerEncoderLayer(
            d_model=wave_dim,
            nhead=num_heads,
            activation="gelu",
            dropout=0,
            norm_first=False,
            batch_first=False,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.fc_weight = nn.Linear(wave_dim, output_dim)
        self.fc_bias = None if self.is_decoder else nn.Linear(wave_dim, embed_dim)

        # Small-scale random init for the learned tokens.
        self.weight_tokens = nn.Parameter(
            torch.randn(self.num_latent_tokens, wave_dim) * 0.02
        )
        self.bias_token = nn.Parameter(torch.randn(1, wave_dim) * 0.02)

    def forward(self, x):
        """
        Args:
            x: Wavelength embeddings of shape [N wave_dim] (one per band).

        Returns:
            tuple: (weights of shape [N output_dim], bias of shape
                [embed_dim] or None for the decoder).
        """
        x = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
        out = self.encoder(x)
        # Band positions only (skip latent tokens and the trailing bias
        # token), with a residual connection around the transformer.
        weights = self.fc_weight(
            out[self.num_latent_tokens : -1] + x[self.num_latent_tokens : -1]
        )
        bias = None if self.is_decoder else self.fc_bias(out[-1])
        return weights, bias


class DynamicEmbedding(nn.Module):
    """
    Wavelength-conditioned patch embedding (encoder) or pixel projection
    (decoder) whose weights are generated per band by ``WavesTransformer``.
    """

    def __init__(
        self,
        wave_dim,
        num_latent_tokens,
        patch_size,
        embed_dim,
        is_decoder=False,
    ):
        super().__init__()
        self.wave_dim = wave_dim
        self.num_latent_tokens = num_latent_tokens
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.is_decoder = is_decoder
        # One patch_size x patch_size kernel per output channel, per band.
        self.output_dim = (patch_size**2) * embed_dim

        self.weight_generator = WavesTransformer(
            wave_dim,
            self.output_dim,
            self.num_latent_tokens,
            self.embed_dim,
            is_decoder,
        )
        self.fclayer = FCBlock(self.wave_dim)

        self.initialize_weights()

    def forward(self, batch, waves):
        """
        Args:
            batch: Encoder: pixels [B C H W]. Decoder: patch embeddings whose
                last dimension is ``embed_dim``.
            waves: Band wavelengths [C].

        Returns:
            tuple: Encoder: ([B L embed_dim] patch embeddings, encoded waves).
                Decoder: (per-patch pixel values, encoded waves).
        """
        waves = posemb_sincos_1d(waves, self.wave_dim)
        waves = waves.to(batch.device)
        waves = self.fclayer(waves)
        weight, bias = self.weight_generator(waves)

        if self.is_decoder:
            # Linear map embed_dim -> (bands * patch_size * patch_size).
            dynamic_weight = rearrange(
                weight,
                "cin (k1 k2 cout) -> (cin k1 k2) cout",
                k1=self.patch_size,
                k2=self.patch_size,
                cout=self.embed_dim,
            )
            x = F.linear(batch, dynamic_weight * 0.02, bias=bias)
        else:
            # Patchify: non-overlapping conv with kernel = stride = patch_size.
            dynamic_weight = rearrange(
                weight,
                "cin (cout k1 k2) -> cout cin k1 k2",
                k1=self.patch_size,
                k2=self.patch_size,
            )
            dynamic_out = F.conv2d(
                batch, dynamic_weight * 0.02, bias=bias, stride=self.patch_size
            )
            x = rearrange(dynamic_out, "b c h w -> b (h w) c")

        return x, waves

    def initialize_weights(self):
        """Xavier-uniform init for every Linear/Conv2d weight, zero bias."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
./src/model.py: import math import os from typing import Literal import lightning as L import timm import torch import torch.nn.functional as F import yaml from box import Box from einops import rearrange, reduce, repeat from torch import nn from torchvision.transforms import v2 from vit_pytorch.simple_vit import Transformer from src.factory import DynamicEmbedding from src.utils import posemb_sincos_2d_with_gsd torch.set_float32_matmul_precision("medium") os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1" class Encoder(nn.Module): def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, shuffle, dim, depth, heads, dim_head, mlp_ratio, ): super().__init__() self.mask_ratio = mask_ratio self.patch_size = patch_size self.shuffle = shuffle self.dim = dim self.cls_token = nn.Parameter(torch.randn(1, 1, dim) CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py 0.02) self.patch_embedding = DynamicEmbedding( wave_dim=128, num_latent_tokens=128, patch_size=patch_size, embed_dim=dim, is_decoder=False, ) self.transformer = Transformer( dim=dim, depth=depth, heads=heads, dim_head=dim_head, mlp_dim=int(dim CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py mlp_ratio), ) def to_patch_embed(self, cube, waves): """Split the input cube into patches & create embeddings per patch""" patches, waves_encoded = self.patch_embedding(cube, waves) # [B L D] return patches, waves_encoded # ([B L D], [N D]) def add_encodings(self, patches, time, latlon, gsd): """Add position encoding to the patches""" B, L, D = patches.shape grid_size = int(math.sqrt(L)) self.num_patches = grid_size**2 pos_encoding = ( posemb_sincos_2d_with_gsd( h=grid_size, w=grid_size, dim=(self.dim - 8), gsd=gsd, ) .to(patches.device) .detach() ) # [L (D - 8)] time_latlon = torch.hstack((time, 
latlon)).to(patches.device).detach() # [B 8] pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] pos_metadata_encoding = torch.cat( (pos_encoding, time_latlon), dim=-1 ) # [B L D] patches = patches + pos_metadata_encoding # [B L D] + [B L D] -> [B L D] return patches # [B L D] def mask_out(self, patches): """ Mask out patches randomly by shuffling the patches & masking out the first N patches Parameters ---------- patches : torch.Tensor A tensor of shape (B, L, D) Returns ------- unmasked_patches : torch.Tensor A tensor of shape (B, L:(1 - mask_ratio), D) containing the embeddings of the unmasked patches. unmasked_indices : torch.Tensor A tensor of shape (B, (1 - mask_ratio)) containing the indices of the unmasked patches. masked_indices : torch.Tensor A tensor of shape (B, mask_ratio) containing the indices of the masked patches. masked_matrix : torch.Tensor A tensor of shape (B, L) containing the mask matrix, 1 indicates a masked patch & 0 indicates an unmasked patch. """ B, L, D = patches.shape # assert ( # L == self.num_patches # ), f"Expected {self.num_patches} patches, got {L} patches." 
if self.shuffle: # Shuffle the patches noise = torch.randn((B, L), device=patches.device) # [B L] else: # Don't shuffle, useful for interpolation & inspection of embeddings noise = rearrange( torch.arange(B CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py L, device=patches.device), "(B L) -> B L", B=B, L=L ) random_indices = torch.argsort(noise, dim=-1) # [B L] reverse_indices = torch.argsort(random_indices, dim=-1) # [B L] num_masked_patches = int( self.mask_ratio CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py self.num_patches ) # Number of patches to be masked out masked_indices, unmasked_indices = ( random_indices[:, :num_masked_patches], # [B mask_ratio CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py L] random_indices[:, num_masked_patches:], # [B (1 - mask_ratio) CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py L] ) # create a mask of shape B L, where 1 indicates a masked patch # and 0 indicates an unmasked patch masked_matrix = torch.zeros((B, L), device=patches.device) # [B L] = 0 masked_matrix[:, :num_masked_patches] = 1 # [B mask_ratio CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py L] = 1 masked_matrix = torch.gather( masked_matrix, dim=1, index=reverse_indices ) # [B L] -> [B L] - reorder the patches # mask out the patches batch_indices = rearrange( torch.arange(B, device=patches.device), "B -> B 1" ) # [B 1] unmasked_patches = patches[ batch_indices, unmasked_indices, : ] # [B L:(1 - 
mask_ratio) D] _ = patches[batch_indices, masked_indices, :] # [B L:mask_ratio D] return ( unmasked_patches, unmasked_indices, masked_indices, masked_matrix, ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] def forward(self, datacube): cube, time, latlon, gsd, waves = ( datacube["pixels"], # [B C H W] datacube["time"], # [B 2] datacube["latlon"], # [B 2] datacube["gsd"], # 1 datacube["waves"], # [N] ) # [B C H W] B, C, H, W = cube.shape patches, waves_encoded = self.to_patch_embed( cube, waves ) # [B L D] - patchify & create embeddings per patch # TODO: Add time & latlon as encoding to patches patches = self.add_encodings( patches, time, latlon, gsd, ) # [B L D] - add position encoding to the embeddings # mask out patches ( unmasked_patches, unmasked_indices, masked_indices, masked_matrix, ) = self.mask_out( patches ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] # Add class tokens cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] unmasked_patches = torch.cat( (cls_tokens, unmasked_patches), dim=1 ) # [B (1 + L) D] # pass the unmasked patches through the transformer encoded_unmasked_patches = self.transformer( unmasked_patches ) # [B ((1 + L)):(1 - mask_ratio)) D] return ( encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, ) # [B ((1 + L):(1 - mask_ratio)) D], [(1-mask_ratio)], [mask_ratio], [B L] class Decoder(nn.Module): def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, encoder_dim, dim, depth, heads, dim_head, mlp_ratio, ): super().__init__() self.mask_ratio = mask_ratio self.patch_size = patch_size self.encoder_dim = encoder_dim self.dim = dim self.enc_to_dec = ( nn.Linear(encoder_dim, dim) if encoder_dim != dim else nn.Identity() ) self.mask_patch = nn.Parameter(torch.randn(dim)) self.transformer = Transformer( dim=dim, depth=depth, heads=heads, dim_head=dim_head, mlp_dim=int(dim CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs 
environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py mlp_ratio), ) self.embed_to_pixels = DynamicEmbedding( wave_dim=128, num_latent_tokens=128, patch_size=patch_size, embed_dim=dim, is_decoder=True, ) def reconstruct_and_add_encoding( # noqa: PLR0913 self, unmasked_patches, unmasked_indices, masked_indices, masked_matrix, time, latlon, gsd, ): B, L = masked_matrix.shape grid_size = int(math.sqrt(L)) self.num_patches = grid_size**2 cls_tokens, unmasked_patches = ( unmasked_patches[:, :1, :], unmasked_patches[:, 1:, :], ) # [B 1 D], [B L:(1 - mask_ratio) D] pos_encoding = ( posemb_sincos_2d_with_gsd( h=grid_size, w=grid_size, dim=(self.dim - 8), gsd=gsd ) .to(unmasked_patches.device) .detach() ) # [L D] time_latlon = ( torch.hstack((time, latlon)).to(unmasked_patches.device).detach() ) # [B 8] pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] pos_metadata_encoding = torch.cat( (pos_encoding, time_latlon), dim=-1 ) # [B L D] batch_indices = rearrange( torch.arange(B, device=unmasked_patches.device), "B -> B 1" ) # [B 1] num_masked_patches = int(self.mask_ratio CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py self.num_patches) masked_patches = repeat( self.mask_patch, "D -> B L D", B=B, L=num_masked_patches ) # [B L:mask_ratio D] # Add position encoding masked_patches = ( masked_patches + pos_metadata_encoding[batch_indices, masked_indices, :] ) # [B L:mask_ratio D] + [B L:mask_ratio D] unmasked_patches = ( unmasked_patches + pos_metadata_encoding[batch_indices, unmasked_indices, :] ) # [B GL:(1 - masked_ratio) D] + [B GL:(1 - mask_ratio) D] # Concatenate the masked & unmasked patches decoder_patches = torch.zeros( (B, self.num_patches, self.dim), device=unmasked_patches.device ) # [B L D] decoder_patches[batch_indices, unmasked_indices, :] = 
( unmasked_patches # [B L:(1 - mask_ratio) D]) ) decoder_patches[batch_indices, masked_indices, :] = ( masked_patches # [B L:mask_ratio D]) ) decoder_patches = torch.cat( (cls_tokens, decoder_patches), dim=1 ) # [B (1 + L) D] return decoder_patches # [B (1 + L) D] def forward( # noqa: PLR0913 self, encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, time, latlon, gsd, waves, ): # Change the embedding dimension from encoder to decoder encoded_unmasked_patches = self.enc_to_dec( encoded_unmasked_patches ) # [B (1 + L) D] # Reconstruct the patches to feed into the decoder transformer decoder_patches = self.reconstruct_and_add_encoding( encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, time, latlon, gsd, ) # [B (1 + L) D] # Pass the decoder patches through the transformer decoded_patches = self.transformer(decoder_patches) # [B (1 + L) D] pixels, waves = self.embed_to_pixels( decoded_patches, waves ) # [B (1 + L) (C P P)] # Remove the class token pixels = pixels[:, 1:, :] return pixels, waves # [B L (C P P)], [B N] class ClayMAE(nn.Module): def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, norm_pix_loss, shuffle, metadata, teacher, # ENCODER dim, depth, heads, dim_head, mlp_ratio, # DECODER decoder_dim, decoder_depth, decoder_heads, decoder_dim_head, decoder_mlp_ratio, **kwargs, ): super().__init__() self.mask_ratio = mask_ratio self.patch_size = patch_size self.norm_pix_loss = norm_pix_loss self.shuffle = shuffle self.metadata = metadata self.teacher = timm.create_model(teacher, pretrained=True, num_classes=0) self.teacher_chip_size = 224 self.teacher_resize = v2.Resize( size=(self.teacher_chip_size, self.teacher_chip_size) ) self.proj = nn.Linear(dim, self.teacher.num_features) self.encoder = Encoder( mask_ratio=mask_ratio, patch_size=patch_size, shuffle=shuffle, dim=dim, depth=depth, heads=heads, dim_head=dim_head, mlp_ratio=mlp_ratio, ) self.decoder = Decoder( mask_ratio=mask_ratio, patch_size=patch_size, 
encoder_dim=dim, dim=decoder_dim, depth=decoder_depth, heads=decoder_heads, dim_head=decoder_dim_head, mlp_ratio=decoder_mlp_ratio, ) self.freeze_teacher() def freeze_teacher(self): for param in self.teacher.parameters(): param.requires_grad = False def per_pixel_loss(self, cube, pixels, masked_matrix): """ cube: [B C H W] pixels: [B L (C P P)] masked_matrix: [B L], 0 is unmasked, 1 is masked """ patches = rearrange( cube, "B C (h p1) (w p2) -> B (h w) (C p1 p2)", p1=self.patch_size, p2=self.patch_size, ) # [B L (C P P)] if self.norm_pix_loss: mean = patches.mean(dim=-1, keepdim=True) var = patches.var(dim=-1, keepdim=True) patches = (patches - mean) / (var + 1e-6) CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py 0.5 loss = F.l1_loss(patches, pixels, reduction="none") # loss per pixel loss = reduce(loss, "B L D -> B L", reduction="mean") # loss per patch loss = ( loss CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py masked_matrix ).sum() / masked_matrix.sum() # loss on masked patches only return loss def forward(self, datacube): """ datacube: dict containing the following keys: - pixels: [B C H W] - time: [B 4] # week hour - latlon: [B 4] # lat lon - platform: [B 1] - date: [B 1] """ platform = datacube["platform"][0] waves = torch.tensor(list(self.metadata[platform].bands.wavelength.values())) gsd = torch.tensor(self.metadata[platform].gsd) # ENCODER ( encoded_unmasked_patches, # [B (1 + L):(1 - mask_ratio) D] unmasked_indices, # [(1-mask_ratio)] masked_indices, # [mask_ratio] masked_matrix, # [B L] ) = self.encoder( { "pixels": datacube["pixels"], "time": datacube["time"], "latlon": datacube["latlon"], "gsd": gsd, "waves": waves, } ) # DECODER pixels, waves = self.decoder( encoded_unmasked_patches, unmasked_indices, masked_indices, 
masked_matrix, datacube["time"], datacube["latlon"], gsd, waves, ) # [B L (C P P)] # LOSS reconstruction_loss = self.per_pixel_loss( datacube["pixels"], pixels, masked_matrix ) # TEACHER encoder_output = self.proj(encoded_unmasked_patches[:, 0, :]) # [B D'] with torch.no_grad(): if platform == "sentinel-1-rtc": r = datacube["pixels"][:, 0, :, :] g = datacube["pixels"][:, 1, :, :] b = r - g rgb = torch.stack((r, g, b), dim=1) else: # Read RGB bands from the sensor to feed the teacher model indices = self.metadata[platform].rgb_indices rgb = datacube["pixels"][:, indices, :, :] rgb = self.teacher_resize(rgb) teacher_output = self.teacher(rgb) representation_loss = -( F.cosine_similarity(encoder_output, teacher_output).mean() - 1.0 # change range from [-1, 1] to [-2, 0] ) # negative cosine similarity, [0, 2] -> 0 is similar & 2 is opposite loss = 0.90 CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py reconstruction_loss + 0.10 CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py representation_loss return (loss, reconstruction_loss, representation_loss) def clay_mae_tiny(**kwargs): args = { # ENCODER "dim": 192, "depth": 6, "heads": 4, "dim_head": 48, "mlp_ratio": 2, # DECODER "decoder_dim": 96, "decoder_depth": 3, "decoder_heads": 2, "decoder_dim_head": 48, "decoder_mlp_ratio": 2, } args.update(kwargs) return ClayMAE(**args) def clay_mae_small(**kwargs): args = { # ENCODER "dim": 384, "depth": 6, "heads": 6, "dim_head": 64, "mlp_ratio": 2, # DECODER "decoder_dim": 192, "decoder_depth": 4, "decoder_heads": 4, "decoder_dim_head": 64, "decoder_mlp_ratio": 2, } args.update(kwargs) return ClayMAE(**args) def clay_mae_base(**kwargs): args = { # ENCODER "dim": 768, "depth": 12, "heads": 12, "dim_head": 64, "mlp_ratio": 4, # DECODER "decoder_dim": 512, 
"decoder_depth": 6, "decoder_heads": 6, "decoder_dim_head": 64, "decoder_mlp_ratio": 4, } args.update(kwargs) return ClayMAE(**args) def clay_mae_large(**kwargs): args = { # ENCODER "dim": 1024, "depth": 24, "heads": 16, "dim_head": 64, "mlp_ratio": 4, # DECODER "decoder_dim": 512, "decoder_depth": 8, "decoder_heads": 8, "decoder_dim_head": 64, "decoder_mlp_ratio": 4, } args.update(kwargs) return ClayMAE(**args) class ClayMAEModule(L.LightningModule): def __init__( # noqa: PLR0913 self, model_size="base", mask_ratio=0.75, norm_pix_loss=False, patch_size=16, shuffle=False, metadata_path="configs/metadata.yaml", teacher="vit_base_patch16_224.dino", lr=1e-4, wd=0.05, b1=0.9, b2=0.95, embeddings_level: Literal["mean", "patch", "group"] = "mean", ): super().__init__() self.save_hyperparameters(logger=True) self.metadata = Box(yaml.safe_load(open(metadata_path))) model_map = { "tiny": clay_mae_tiny, "small": clay_mae_small, "base": clay_mae_base, "large": clay_mae_large, } if model_size in model_map: model_args = { "mask_ratio": mask_ratio, "patch_size": patch_size, "norm_pix_loss": norm_pix_loss, "shuffle": shuffle, "metadata": self.metadata, "teacher": teacher, } self.model = model_map[model_size](**model_args) else: raise ValueError( f"Invalid model size {model_size}. 
Expected one of {model_map.keys()}" ) def on_train_epoch_start(self): self.model.teacher.eval() def forward(self, datacube: dict[str, torch.Tensor]): return self.model(datacube) def configure_optimizers(self): optimizer = torch.optim.AdamW( self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.wd, betas=(self.hparams.b1, self.hparams.b2), ) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=1000, T_mult=2, eta_min=self.hparams.lr CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py 100, last_epoch=-1 ) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "interval": "step", }, } def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, phase: str): datacube = batch loss, reconstruction_loss, representation_loss = self(datacube) self.log( name=f"{phase}/loss", value=loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( name=f"{phase}/rec_loss", value=reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( name=f"{phase}/rep_loss", value=representation_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) return loss def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): return self.shared_step(batch, batch_idx, phase="train") def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int): return self.shared_step(batch, batch_idx, phase="val")
# ./src/utils.py
"""
Fixed sine-cosine positional embeddings.

Code from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py
"""

import torch


def _frequency_ramp(num_freqs):
    """Return ``num_freqs`` evenly spaced exponents from 0 to 1 (inclusive).

    The upstream expression ``arange(n) / (n - 1)`` divides by zero when
    ``n == 1`` and yields NaN embeddings; a single frequency gets exponent 0.
    For ``n >= 2`` the result is identical to the upstream expression.
    """
    return torch.arange(num_freqs) / max(num_freqs - 1, 1)


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """Build a 2D sine-cosine positional embedding for an ``h x w`` grid.

    Args:
        h: number of grid rows.
        w: number of grid columns.
        dim: embedding size; must be a multiple of 4 (one quarter each for
            sin(x), cos(x), sin(y), cos(y)).
        temperature: base of the geometric frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)``; rows follow row-major grid order and
        columns are ``[sin(x), cos(x), sin(y), cos(y)]``.
    """
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    omega = _frequency_ramp(dim // 4)
    omega = 1.0 / (temperature**omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def posemb_sincos_2d_with_gsd(
    h, w, dim, gsd=1.0, temperature: int = 10000, dtype=torch.float32
):
    """Build a 2D sine-cosine positional embedding scaled by ground sample distance.

    Same layout as :func:`posemb_sincos_2d`, but every frequency is multiplied
    by ``gsd``, so embeddings vary faster across the grid for larger ``gsd``.

    Args:
        h: number of grid rows.
        w: number of grid columns.
        dim: embedding size; must be a multiple of 4.
        gsd: ground sample distance (float or scalar tensor).
        temperature: base of the frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)``.

    NOTE(review): ``omega`` is already normalised to [0, 1] and is divided by
    ``dim`` again inside the exponent, so frequencies only span
    ``[temperature ** (-2 / dim), 1]`` (times ``gsd``). Pretrained checkpoints
    presumably depend on this exact formula, so it is kept unchanged.
    """
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    omega = _frequency_ramp(dim // 4)
    # Frequencies scale linearly with the ground sample distance.
    omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32):
    """Build a 1D sine-cosine embedding.

    Args:
        pos: either an int ``n`` (positions ``0 .. n-1`` are used) or a 1D
            tensor of positions.
        dim: embedding size; must be even (half sin, half cos).
        temperature: base of the geometric frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(len(pos), dim)`` with columns ``[sin, cos]``.
    """
    assert (
        dim % 2 == 0
    ), "Feature dimension must be a multiple of 2 for sincos embedding"
    pos = torch.arange(pos) if isinstance(pos, int) else pos

    omega = _frequency_ramp(dim // 2)
    omega = 1.0 / (temperature**omega)

    scaled_pos = pos[:, None] * omega[None, :]
    pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1)
    return pe.type(dtype)
# ./src/datamodule.py
"""
LightningDataModule to load Earth Observation data from ``.npz`` chip files.
"""

from collections import defaultdict
from pathlib import Path
from typing import List, Literal

import lightning as L
import numpy as np
import torch
import torchdata
import yaml
from box import Box
from einops import rearrange
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import Sampler
from torchvision.transforms import v2


class EODataset(Dataset):
    """Reads different Earth Observation data sources from a directory.

    Each item is one ``.npz`` file. The item's platform is the name of the
    file's parent directory, which selects the augmentation/normalisation
    pipeline built for that platform.

    Args:
        chips_path: paths of the ``.npz`` files to serve.
        size: side length of the random square crop.
        platforms: platform names to build transforms for.
        metadata: per-platform metadata; ``metadata[platform].bands.mean`` and
            ``metadata[platform].bands.std`` supply the normalisation stats.
    """

    def __init__(
        self, chips_path: List[Path], size: int, platforms: list, metadata: Box
    ) -> None:
        super().__init__()
        self.chips_path = chips_path
        self.size = size
        self.transforms = {}

        # Generate transforms for each platform using a helper function
        for platform in platforms:
            mean = list(metadata[platform].bands.mean.values())
            std = list(metadata[platform].bands.std.values())
            self.transforms[platform] = self.create_transforms(mean, std)

    def create_transforms(self, mean, std):
        """Random flips and a random ``size`` crop, then per-band normalisation."""
        return v2.Compose(
            [
                v2.RandomHorizontalFlip(p=0.5),
                v2.RandomVerticalFlip(p=0.5),
                v2.RandomCrop(size=(self.size, self.size)),
                v2.Normalize(mean=mean, std=std),
            ]
        )

    def __len__(self):
        return len(self.chips_path)

    def __getitem__(self, idx):
        chip_path = self.chips_path[idx]
        platform = chip_path.parent.name
        with np.load(chip_path, allow_pickle=False) as chip:
            pixels = torch.from_numpy(chip["pixels"].astype(np.float32))
            pixels = self.transforms[platform](pixels)

            # Prepare additional information
            additional_info = {
                "platform": platform,
                "time": torch.tensor(
                    np.hstack((chip["week_norm"], chip["hour_norm"])),
                    dtype=torch.float32,
                ),
                "latlon": torch.tensor(
                    np.hstack((chip["lat_norm"], chip["lon_norm"])),
                    dtype=torch.float32,
                ),
            }

        return {"pixels": pixels, **additional_info}


class ClaySampler(Sampler):
    """Batch sampler that yields single-platform batches, cycling over platforms.

    Each platform's shuffled indices are tiled up to the size of the largest
    platform, so every platform contributes the same number of batches per
    epoch. Incomplete trailing batches are dropped. Platforms with no chips
    are skipped because they cannot form a batch.
    """

    def __init__(self, dataset, platforms, batch_size):
        self.dataset = dataset
        self.platforms = platforms
        self.batch_size = batch_size

        self.cubes_per_platform = {platform: [] for platform in platforms}
        for idx, chip_path in enumerate(self.dataset.chips_path):
            platform = chip_path.parent.name
            self.cubes_per_platform[platform].append(idx)

    def _active_platforms(self):
        """Platforms (in configured order) that own at least one chip."""
        return [p for p in self.platforms if self.cubes_per_platform[p]]

    def _max_len(self):
        """Length every platform's index list is tiled to in one epoch."""
        return max(
            (len(self.cubes_per_platform[p]) for p in self._active_platforms()),
            default=0,
        )

    def __iter__(self):
        cubes_per_platform_per_epoch = {}
        rng = np.random.default_rng()
        platforms = self._active_platforms()
        max_len = self._max_len()

        # Shuffle and adjust sizes
        for platform in platforms:
            indices = self.cubes_per_platform[platform]
            rng.shuffle(indices)
            repeats = max_len // len(indices) + 1
            cubes_per_platform_per_epoch[platform] = np.tile(indices, repeats)[
                :max_len
            ]

        # Create batches such that we return one platform per batch in cycle
        # Ignore the last batch if it is incomplete
        for i in range(0, max_len, self.batch_size):
            for platform in platforms:
                batch = cubes_per_platform_per_epoch[platform][i : i + self.batch_size]
                if len(batch) == self.batch_size:
                    yield batch

    def __len__(self):
        # Must match __iter__: each active platform yields one full batch per
        # complete ``batch_size`` window of its tiled index list.
        return len(self._active_platforms()) * (self._max_len() // self.batch_size)


def batch_collate(batch):
    """Collate function for DataLoader.

    Each dataset item already holds a stack of chips (its pixels are
    ``[b2 C H W]``), so the items are stacked and the first two dimensions
    (items x chips per item) are merged into one batch dimension. Platforms
    are returned as a plain list, one entry per item.
    """
    d = defaultdict(list)
    for item in batch:
        d["pixels"].append(item["pixels"])
        d["time"].append(item["time"])
        d["latlon"].append(item["latlon"])
        d["platform"].append(item["platform"])
    return {
        "pixels": rearrange(d["pixels"], "b1 b2 c h w -> (b1 b2) c h w"),
        "time": rearrange(d["time"], "b1 b2 t -> (b1 b2) t"),
        "latlon": rearrange(d["latlon"], "b1 b2 ll -> (b1 b2) ll"),
        "platform": d["platform"],
    }


class ClayDataModule(L.LightningDataModule):
    """Datamodule serving single-platform batches of ``.npz`` chips.

    Args:
        data_dir: local directory (searched recursively for ``*.npz``) or an
            ``s3://`` prefix.
        size: random crop size applied by :class:`EODataset`.
        metadata_path: YAML file with per-platform band statistics.
        platforms: platform names; each must match a chip parent directory.
        batch_size: number of ``.npz`` items per batch.
        num_workers: DataLoader worker processes.
    """

    def __init__(  # noqa: PLR0913
        self,
        data_dir: str = "data",
        size: int = 224,
        metadata_path: str = "configs/metadata.yaml",
        platforms: list = [
            "landsat-c2l1",
            "landsat-c2l2-sr",
            "linz",
            "naip",
            "sentinel-1-rtc",
            "sentinel-2-l2a",
        ],
        batch_size: int = 10,
        num_workers: int = 8,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.size = size
        # Copy so instances never alias the shared default list.
        self.platforms = list(platforms)
        with open(metadata_path) as f:
            self.metadata = Box(yaml.safe_load(f))
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_ratio = 0.8

    def setup(self, stage: Literal["fit", "predict"] | None = None) -> None:
        # Get list of chip filepaths from s3 bucket or data/ folder
        if self.data_dir.startswith("s3://"):
            # NOTE(review): list_files_by_s3 yields strings, while the code
            # below and EODataset call ``.parent`` on each path; verify this
            # branch before relying on it.
            dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.data_dir])
            chips_path = list(dp.list_files_by_s3(masks="*.npz"))
        else:  # if self.data_dir is a local data path
            chips_path = sorted(Path(self.data_dir).glob("**/*.npz"))

        # Stratify by platform using the same rule as EODataset/ClaySampler
        # (the chip's parent directory name), so every platform is present in
        # both the train and the validation split.
        chips_platform = [chip.parent.name for chip in chips_path]
        print(f"Total number of chips: {len(chips_path)}")

        if stage == "fit":
            trn_paths, val_paths = train_test_split(
                chips_path,
                test_size=(1 - self.split_ratio),
                stratify=chips_platform,
                shuffle=True,
            )

            self.trn_ds = EODataset(
                chips_path=trn_paths,
                size=self.size,
                platforms=self.platforms,
                metadata=self.metadata,
            )
            self.trn_sampler = ClaySampler(
                dataset=self.trn_ds,
                platforms=self.platforms,
                batch_size=self.batch_size,
            )
            self.val_ds = EODataset(
                chips_path=val_paths,
                size=self.size,
                platforms=self.platforms,
                metadata=self.metadata,
            )
            self.val_sampler = ClaySampler(
                dataset=self.val_ds,
                platforms=self.platforms,
                batch_size=self.batch_size,
            )

        elif stage == "predict":
            # NOTE(review): EODataset always applies random flips and a random
            # crop, so predictions see augmented chips; confirm intended.
            self.prd_ds = EODataset(
                chips_path=chips_path,
                size=self.size,
                platforms=self.platforms,
                metadata=self.metadata,
            )

    def train_dataloader(self):
        return DataLoader(
            self.trn_ds,
            num_workers=self.num_workers,
            batch_sampler=self.trn_sampler,
            collate_fn=batch_collate,
            pin_memory=True,
            prefetch_factor=4,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            num_workers=self.num_workers,
            batch_sampler=self.val_sampler,
            collate_fn=batch_collate,
            pin_memory=True,
            prefetch_factor=4,
        )

    def predict_dataloader(self):
        # NOTE(review): uses the default collate instead of batch_collate, so
        # pixels keep an extra leading (items) dimension compared with
        # training batches; confirm against the predict step.
        return DataLoader(
            dataset=self.prd_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
# ./src/callbacks.py
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks.finetuning import BaseFinetuning


class ProgressiveResizing(Callback):
    """Update the datamodule's chip size, batch size and workers on a schedule.

    Args:
        resize_schedule (dict, optional): maps an epoch number to a dict with
            ``"batch_size"``, ``"num_workers"`` and ``"size"`` keys. When the
            current epoch is a key, those values are written onto
            ``trainer.datamodule`` and its ``setup`` is called again. Defaults
            to sizes 64 / 128 / 256 starting at epochs 0 / 10 / 20.

    NOTE(review): this only mutates the datamodule and re-runs ``setup``.
    Whether the trainer rebuilds its dataloaders afterwards depends on the
    Trainer configuration (e.g. ``reload_dataloaders_every_n_epochs``); confirm.
    If the datamodule re-draws a random train/val split in ``setup("fit")``,
    each schedule step also reshuffles the split.
    """

    def __init__(self, resize_schedule=None):
        super().__init__()
        if resize_schedule is None:
            resize_schedule = {
                0: {"batch_size": 4, "num_workers": 4, "size": 64},
                10: {"batch_size": 2, "num_workers": 2, "size": 128},
                20: {"batch_size": 1, "num_workers": 1, "size": 256},
            }
        self.resize_schedule = resize_schedule

    def _apply_schedule(self, trainer, stage):
        """Apply the current epoch's settings (if scheduled), then re-run setup."""
        if trainer.current_epoch not in self.resize_schedule:
            return
        params = self.resize_schedule[trainer.current_epoch]
        datamodule = trainer.datamodule
        datamodule.size = params["size"]
        datamodule.batch_size = params["batch_size"]
        datamodule.num_workers = params["num_workers"]
        datamodule.setup(stage=stage)

    def on_train_epoch_start(self, trainer, pl_module):
        self._apply_schedule(trainer, stage="fit")

    def on_validation_epoch_start(self, trainer, pl_module):
        self._apply_schedule(trainer, stage="validate")


class LayerwiseFinetuning(BaseFinetuning):
    """Keep the encoder frozen until a given epoch, then train it as well."""

    def __init__(self, phase, train_bn=True):
        """Initializes with phase & batch-norm information.

        Args:
            phase (int): Epoch at which the encoder's patch embedding and
                transformer are unfrozen and added to the optimizer. It is
                compared with ``==`` against the current epoch, so it must be
                a single epoch number.
            train_bn (bool, optional): Trains just the batch-norm layers even
                when all the other layers of the network are frozen. Defaults
                to True.
        """
        super().__init__()
        self.phase = phase
        self.train_bn = train_bn

    @staticmethod
    def _encoder_modules(pl_module):
        """Encoder sub-modules that are frozen and later unfrozen together."""
        encoder = pl_module.model.encoder
        return [encoder.patch_embedding, encoder.transformer]

    def freeze_before_training(self, pl_module):
        """Freezes the encoder before starting the training."""
        self.freeze(
            modules=self._encoder_modules(pl_module),
            train_bn=self.train_bn,
        )

    def finetune_function(self, pl_module, epoch, optimizer):
        if epoch == self.phase:
            # Unfreeze the encoder and add its parameters to the optimizer.
            print(f"In Phase {self.phase}: Full throttle")
            self.unfreeze_and_add_param_group(
                modules=self._encoder_modules(pl_module),
                optimizer=optimizer,
                train_bn=self.train_bn,
            )
            params = list(pl_module.parameters())
            num_active = sum(p.requires_grad for p in params)
            print(f"active: {num_active}, all: {len(params)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment