find . -name "*.py" -print0 | xargs -0 -I {} sh -c 'echo {}: $(cat {})' >> model-all.py.txt
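Note that the `echo {}: $(cat {})` trick has two side effects: the shell collapses each file's newlines into spaces, and any bare `*` in the source is glob-expanded into the directory listing, which is why a stray file list shows up below wherever the code contained a `*`. A minimal Python sketch (output filename matching the command above) that concatenates the same files without those artifacts:

```python
# Minimal sketch: concatenate every *.py file with its path header while
# preserving newlines, avoiding the shell's word splitting and glob expansion.
from pathlib import Path

with open("model-all.py.txt", "w") as out:
    for path in sorted(Path(".").rglob("*.py")):
        out.write(f"{path}:\n")       # file path header, like the find/xargs version
        out.write(path.read_text())   # file body, unmodified
        out.write("\n")
```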
./finetune/classify/classify.py: """ Command line interface to run the neural network model! From the project root directory, do: python classify.py fit --config configs/classify_eurosat.yaml References: - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 """ from lightning.pytorch.cli import LightningCLI from finetune.classify.eurosat_datamodule import EuroSATDataModule # noqa: F401 from finetune.classify.eurosat_model import EuroSATClassifier # noqa: F401 # %% def cli_main(): """ Command-line interface to run Classifier model with EuroSATDataModule. """ cli = LightningCLI(EuroSATClassifier, EuroSATDataModule) return cli # %% if __name__ == "__main__": cli_main() print("Done!")
./finetune/classify/factory.py: import re import torch from torch import nn from src.model import Encoder class Classifier(nn.Module): """ Classifier class uses Clay Encoder for feature extraction and a head for classification. Attributes: clay_encoder (Encoder): The encoder for feature extraction. head (nn.Sequential): The head for classification. device (torch.device): The device to run the model on. """ def __init__(self, num_classes=10, ckpt_path=None): """ Initialize the Classifier. Args: num_classes (int, optional): The number of classes for classification. Defaults to 10. ckpt_path (str, optional): Clay MAE pretrained model checkpoint path. Defaults to None. """ super().__init__() # Initialize Clay Encoder with parameters from base model. Set # mask_ratio to 0.0 & shuffle to False for downstream tasks. self.clay_encoder = Encoder( mask_ratio=0.0, patch_size=8, shuffle=False, dim=768, depth=12, heads=12, dim_head=64, mlp_ratio=4.0, ) # Simple 2 layer MLP head for classification self.head = nn.Sequential( nn.Linear(768, 512), nn.ReLU(), nn.Dropout(0.25), nn.Linear(512, num_classes), ) # Determine the device to run the model on self.device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) # Load Clay MAE pretrained weights for the Encoder self.load_clay_weights(ckpt_path) def load_clay_weights(self, ckpt_path): """ Load the weights for Clay MAE Encoder from a checkpoint file. Args: ckpt_path (str): Clay MAE pretrained model checkpoint path. """ # Load the checkpoint file ckpt = torch.load(ckpt_path, map_location=self.device) state_dict = ckpt.get("state_dict") # Remove model.encoder prefix for the clay encoder state_dict = { re.sub(r"^model\.encoder\.", "", name): param for name, param in state_dict.items() if name.startswith("model.encoder") } # Copy the weights from the state dict to the encoder for name, param in self.clay_encoder.named_parameters(): if name in state_dict and param.size() == state_dict[name].size(): param.data.copy_(state_dict[name]) # Copy the weights else: print(f"No matching parameter for {name} with size {param.size()}") # Freeze clay encoder for param in self.clay_encoder.parameters(): param.requires_grad = False # Set the encoder to evaluation mode self.clay_encoder.eval() def forward(self, datacube): """ Forward pass of the Classifier. Args: datacube (torch.Tensor): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelenths. Returns: torch.Tensor: The output logits. """ # Get the embeddings from the encoder embeddings, *_ = self.clay_encoder( datacube ) # embeddings: batch x (1 + row x col) x 768 # Use only the first embedding i.e cls token embeddings = embeddings[:, 0, :] # Pass the embeddings through the head to get the logits logits = self.head(embeddings) return logits | |
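A minimal sketch (the checkpoint path is hypothetical) of how the frozen-encoder `Classifier` above can be exercised end to end; the datacube mirrors what `EuroSATClassifier.forward` builds, with the 10 Sentinel-2 bands and 64x64 EuroSAT chips:

```python
# Minimal sketch of a forward pass through Classifier (checkpoint path assumed).
import torch
from finetune.classify.factory import Classifier

model = Classifier(num_classes=10, ckpt_path="checkpoints/clay-v1-base.ckpt")
datacube = {
    "pixels": torch.randn(2, 10, 64, 64),  # [B C H W] EuroSAT-sized chips
    "time": torch.zeros(2, 4),             # placeholder, as in the datamodule
    "latlon": torch.zeros(2, 4),           # placeholder, as in the datamodule
    "gsd": torch.tensor(10.0),
    "waves": torch.tensor(
        [0.493, 0.56, 0.665, 0.704, 0.74, 0.783, 0.842, 0.865, 1.61, 2.19]
    ),
}
logits = model(datacube)
print(logits.shape)  # expected: torch.Size([2, 10])
```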
./finetune/classify/eurosat_datamodule.py: import lightning as L import torch import yaml from box import Box from torch.utils.data import DataLoader from torchgeo.datasets import EuroSAT as TGEuroSAT from torchvision.transforms import v2 S2_BANDS = [ "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B08A", "B11", "B12", ] class EuroSAT(TGEuroSAT): """ Subclass of TGEuroSAT to customize the dataset loading and transformations. Args: root (str): Root directory of the dataset. split (str): Dataset split to use ('train' or 'val'). bands (list): List of spectral bands to use. transforms (callable): Transformations to apply to the samples. download (bool): If true, downloads the dataset. """ def __init__(self, root, split, bands, transforms, download): super().__init__(root, split, bands, transforms, download) def __getitem__(self, index): """ Override the __getitem__ method to apply custom transformations. Args: index (int): Index of the sample to retrieve. Returns: dict: A dictionary containing the image tensor, label, and additional metadata. """ image, label = self._load_image(index) image = torch.index_select(image, dim=0, index=self.band_indices).float() sample = { "pixels": image, "label": label, "time": torch.zeros(4), # Placeholder for time information "latlon": torch.zeros(4), # Placeholder for lat/lon information } if self.transforms is not None: sample = self.transforms(sample) return sample class EuroSATDataModule(L.LightningDataModule): """ Data module for loading and transforming the EuroSAT dataset. Args: batch_size (int): Batch size for the dataloaders. num_workers (int): Number of workers for data loading. metadata_path (str): Path to the metadata file for normalization statistics. """ def __init__(self, batch_size, num_workers, metadata_path): super().__init__() self.batch_size = batch_size self.num_workers = num_workers metadata = Box(yaml.safe_load(open(metadata_path)))["sentinel-2-l2a"] mean = list(metadata.bands.mean.values()) std = list(metadata.bands.std.values()) self.trn_tfm = v2.Compose( [ v2.RandomHorizontalFlip(), v2.RandomVerticalFlip(), v2.Normalize(mean, std), ] ) self.val_tfm = v2.Compose([v2.Normalize(mean, std)]) def setup(self, stage=None): """ Setup the datasets for training and validation. Args: stage (str): Stage of the training process ('fit', 'validate', etc.). """ if stage in {"fit", None}: self.trn_ds = EuroSAT( root="data", split="train", bands=S2_BANDS, transforms=self.trn_tfm, download=True, ) self.val_ds = EuroSAT( root="data", split="val", bands=S2_BANDS, transforms=self.val_tfm, download=True, ) def train_dataloader(self): """ Returns the DataLoader for the training dataset. Returns: DataLoader: DataLoader for the training dataset. """ return DataLoader( self.trn_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True, prefetch_factor=2, ) def val_dataloader(self): """ Returns the DataLoader for the validation dataset. Returns: DataLoader: DataLoader for the validation dataset. """ return DataLoader( self.val_ds, batch_size=self.batch_size * 2, shuffle=False, num_workers=self.num_workers, pin_memory=True, prefetch_factor=2, )
./finetune/classify/eurosat_model.py: import lightning as L import torch from torch import nn, optim from torchmetrics import Accuracy from finetune.classify.factory import Classifier class EuroSATClassifier(L.LightningModule): """ LightningModule for training and evaluating a classifier on the EuroSAT dataset. Args: num_classes (int): Number of classes for classification. ckpt_path (str): Clay MAE pretrained checkpoint path. lr (float): Learning rate for the optimizer. wd (float): Weight decay for the optimizer. b1 (float): Beta1 parameter for the Adam optimizer. b2 (float): Beta2 parameter for the Adam optimizer. """ def __init__(self, num_classes, ckpt_path, lr, wd, b1, b2): # noqa: PLR0913 super().__init__() self.save_hyperparameters() self.model = Classifier(num_classes=num_classes, ckpt_path=ckpt_path) self.loss_fn = nn.CrossEntropyLoss() self.accuracy = Accuracy(task="multiclass", num_classes=num_classes) def forward(self, datacube): """ Forward pass through the classifier. Args: datacube (dict): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelengths. Returns: torch.Tensor: The output logits from the classifier. """ # Wavelengths for Sentinel 2 bands of EuroSAT dataset waves = torch.tensor( [0.493, 0.56, 0.665, 0.704, 0.74, 0.783, 0.842, 0.865, 1.61, 2.19] ) gsd = torch.tensor(10.0) return self.model( { "pixels": datacube["pixels"], "time": datacube["time"], "latlon": datacube["latlon"], "gsd": gsd, "waves": waves, } ) def configure_optimizers(self): """ Configure the optimizer and learning rate scheduler. Returns: dict: A dictionary containing the optimizer and learning rate scheduler. """ optimizer = optim.AdamW( [ param for name, param in self.model.named_parameters() if param.requires_grad ], lr=self.hparams.lr, weight_decay=self.hparams.wd, betas=(self.hparams.b1, self.hparams.b2), ) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=100, T_mult=1, eta_min=self.hparams.lr * 100, last_epoch=-1 ) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "interval": "step", }, } def shared_step(self, batch, batch_idx, phase): """ Perform a shared step for both training and validation. Args: batch (dict): A batch of data. batch_idx (int): The index of the batch. phase (str): The phase ('train' or 'val'). Returns: torch.Tensor: The computed loss for the batch. """ labels = batch["label"].long() logits = self(batch) loss = self.loss_fn(logits, labels) score = self.accuracy(logits, labels) self.log( f"{phase}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( f"{phase}/score", score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) return loss def training_step(self, batch, batch_idx): """ Perform a training step. Args: batch (dict): A batch of training data. batch_idx (int): The index of the batch. Returns: torch.Tensor: The computed loss for the batch. """ return self.shared_step(batch, batch_idx, "train") def validation_step(self, batch, batch_idx): """ Perform a validation step. Args: batch (dict): A batch of validation data. batch_idx (int): The index of the batch. Returns: torch.Tensor: The computed loss for the batch. """ return self.shared_step(batch, batch_idx, "val")
./finetune/segment/preprocess_data.py: r""" Chesapeake CVPR Data Processing Script ====================================== This script processes GeoTIFF files from the Chesapeake CVPR dataset to create image chips for segmentation tasks. Dataset Source: --------------- Chesapeake CVPR data from LILA: https://lila.science/datasets/chesapeakelandcover For this experiment, we will use images from NY. Notes: ------ 1. Only copy *_lc.tif & *_naip-new.tif files that we will use for our segmentation downstream task. Using s5cmd for this: https://github.com/peak/s5cmd - Train: s5cmd cp \ --no-sign-request \ --include "*_lc.tif" \ --include "*_naip-new.tif" \ "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" \ data/cvpr/files/train/ - Val: s5cmd cp \ --no-sign-request \ --include "*_lc.tif" \ --include "*_naip-new.tif" \ "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" \ data/cvpr/files/val/ 2. We will create chips of size `224 x 224` to feed them to the model, feel free to experiment with other chip sizes as well. Run the script as follows: python preprocess_data.py <data_dir> <output_dir> <chip_size> Example: python preprocess_data.py data/cvpr/files data/cvpr/ny 224 """ # noqa E501 import os import sys from pathlib import Path import numpy as np import rasterio as rio def read_and_chip(file_path, chip_size, output_dir): """ Reads a GeoTIFF file, creates chips of specified size, and saves them as numpy arrays. Args: file_path (str or Path): Path to the GeoTIFF file. chip_size (int): Size of the square chips. output_dir (str or Path): Directory to save the chips. """ os.makedirs(output_dir, exist_ok=True) with rio.open(file_path) as src: data = src.read() n_chips_x = src.width // chip_size n_chips_y = src.height // chip_size chip_number = 0 for i in range(n_chips_x): for j in range(n_chips_y): x1, y1 = i * chip_size, j * chip_size x2, y2 = x1 + chip_size, y1 + chip_size chip = data[:, y1:y2, x1:x2] chip_path = os.path.join( output_dir, f"{Path(file_path).stem}_chip_{chip_number}.npy", ) np.save(chip_path, chip) chip_number += 1 def process_files(file_paths, output_dir, chip_size): """ Processes a list of files, creating chips and saving them. Args: file_paths (list of Path): List of paths to the GeoTIFF files. output_dir (str or Path): Directory to save the chips. chip_size (int): Size of the square chips. """ for file_path in file_paths: print(f"Processing: {file_path}") read_and_chip(file_path, chip_size, output_dir) def main(): """ Main function to process files and create chips. Expects three command line arguments: - data_dir: Directory containing the input GeoTIFF files. - output_dir: Directory to save the output chips. - chip_size: Size of the square chips.
""" if len(sys.argv) != 4: # noqa: PLR2004 print("Usage: python script.py <data_dir> <output_dir> <chip_size>") sys.exit(1) data_dir = Path(sys.argv[1]) output_dir = Path(sys.argv[2]) chip_size = int(sys.argv[3]) train_image_paths = list((data_dir / "train").glob("*_naip-new.tif")) val_image_paths = list((data_dir / "val").glob("*_naip-new.tif")) train_label_paths = list((data_dir / "train").glob("*_lc.tif")) val_label_paths = list((data_dir / "val").glob("*_lc.tif")) process_files(train_image_paths, output_dir / "train/chips", chip_size) process_files(val_image_paths, output_dir / "val/chips", chip_size) process_files(train_label_paths, output_dir / "train/labels", chip_size) process_files(val_label_paths, output_dir / "val/labels", chip_size) if __name__ == "__main__": main() | |
./finetune/segment/chesapeake_model.py: """ LightningModule for training and validating a segmentation model using the Segmentor class. """ import lightning as L import segmentation_models_pytorch as smp import torch import torch.nn.functional as F from torch import optim from torchmetrics.classification import F1Score, MulticlassJaccardIndex from finetune.segment.factory import Segmentor class ChesapeakeSegmentor(L.LightningModule): """ LightningModule for segmentation tasks, utilizing Clay Segmentor. Attributes: model (nn.Module): Clay Segmentor model. loss_fn (nn.Module): The loss function. iou (Metric): Intersection over Union metric. f1 (Metric): F1 Score metric. lr (float): Learning rate. """ def __init__( # noqa: PLR0913 self, num_classes, feature_maps, ckpt_path, lr, wd, b1, b2, ): super().__init__() self.save_hyperparameters() # Save hyperparameters for checkpointing self.model = Segmentor( num_classes=num_classes, feature_maps=feature_maps, ckpt_path=ckpt_path, ) self.loss_fn = smp.losses.FocalLoss(mode="multiclass") self.iou = MulticlassJaccardIndex( num_classes=num_classes, average="weighted", ) self.f1 = F1Score( task="multiclass", num_classes=num_classes, average="weighted", ) def forward(self, datacube): """ Forward pass through the segmentation model. Args: datacube (dict): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelengths. Returns: torch.Tensor: The segmentation logits. """ waves = torch.tensor([0.65, 0.56, 0.48, 0.842]) # NAIP wavelengths gsd = torch.tensor(1.0) # NAIP GSD # Forward pass through the network return self.model( { "pixels": datacube["pixels"], "time": datacube["time"], "latlon": datacube["latlon"], "gsd": gsd, "waves": waves, }, ) def configure_optimizers(self): """ Configure the optimizer and learning rate scheduler. Returns: dict: A dictionary containing the optimizer and scheduler configuration. """ optimizer = optim.AdamW( [ param for name, param in self.model.named_parameters() if param.requires_grad ], lr=self.hparams.lr, weight_decay=self.hparams.wd, betas=(self.hparams.b1, self.hparams.b2), ) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=1000, T_mult=1, eta_min=self.hparams.lr * 100, last_epoch=-1, ) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "interval": "step", }, } def shared_step(self, batch, batch_idx, phase): """ Shared step for training and validation. Args: batch (dict): A dictionary containing the batch data. batch_idx (int): The index of the batch. phase (str): The phase (train or val). Returns: torch.Tensor: The loss value. """ labels = batch["label"].long() outputs = self(batch) outputs = F.interpolate( outputs, size=(224, 224), mode="bilinear", align_corners=False, ) # Resize to match labels size loss = self.loss_fn(outputs, labels) iou = self.iou(outputs, labels) f1 = self.f1(outputs, labels) # Log metrics self.log( f"{phase}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( f"{phase}/iou", iou, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( f"{phase}/f1", f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) return loss def training_step(self, batch, batch_idx): """ Training step for the model. Args: batch (dict): A dictionary containing the batch data.
batch_idx (int): The index of the batch. Returns: torch.Tensor: The loss value. """ return self.shared_step(batch, batch_idx, "train") def validation_step(self, batch, batch_idx): """ Validation step for the model. Args: batch (dict): A dictionary containing the batch data. batch_idx (int): The index of the batch. Returns: torch.Tensor: The loss value. """ return self.shared_step(batch, batch_idx, "val") | |
./finetune/segment/chesapeake_datamodule.py: """ DataModule for the Chesapeake Bay dataset for segmentation tasks. This implementation provides a structured way to handle the data loading and preprocessing required for training and validating a segmentation model. Dataset citation: Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N. Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data. Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition (CVPR 2019). Dataset URL: https://lila.science/datasets/chesapeakelandcover """ import re from pathlib import Path import lightning as L import numpy as np import torch import yaml from box import Box from torch.utils.data import DataLoader, Dataset from torchvision.transforms import v2 class ChesapeakeDataset(Dataset): """ Dataset class for the Chesapeake Bay segmentation dataset. Args: chip_dir (str): Directory containing the image chips. label_dir (str): Directory containing the labels. metadata (Box): Metadata for normalization and other dataset-specific details. platform (str): Platform identifier used in metadata. """ def __init__(self, chip_dir, label_dir, metadata, platform): self.chip_dir = Path(chip_dir) self.label_dir = Path(label_dir) self.metadata = metadata self.transform = self.create_transforms( mean=list(metadata[platform].bands.mean.values()), std=list(metadata[platform].bands.std.values()), ) # Load chip and label file names self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npy")] self.labels = [re.sub("_naip-new_", "_lc_", chip) for chip in self.chips] def create_transforms(self, mean, std): """ Create normalization transforms. Args: mean (list): Mean values for normalization. std (list): Standard deviation values for normalization. Returns: torchvision.transforms.Compose: A composition of transforms. """ return v2.Compose( [ v2.Normalize(mean=mean, std=std), ], ) def __len__(self): return len(self.chips) def __getitem__(self, idx): """ Get a sample from the dataset. Args: idx (int): Index of the sample. Returns: dict: A dictionary containing the image, label, and additional information. """ chip_name = self.chip_dir / self.chips[idx] label_name = self.label_dir / self.labels[idx] chip = np.load(chip_name).astype(np.float32) label = np.load(label_name) # Remap labels to match desired classes label_mapping = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 15: 6} remapped_label = np.vectorize(label_mapping.get)(label) sample = { "pixels": self.transform(torch.from_numpy(chip)), "label": torch.from_numpy(remapped_label[0]), "time": torch.zeros(4), # Placeholder for time information "latlon": torch.zeros(4), # Placeholder for latlon information } return sample class ChesapeakeDataModule(L.LightningDataModule): """ DataModule class for the Chesapeake Bay dataset. Args: train_chip_dir (str): Directory containing training image chips. train_label_dir (str): Directory containing training labels. val_chip_dir (str): Directory containing validation image chips. val_label_dir (str): Directory containing validation labels. metadata_path (str): Path to the metadata file. batch_size (int): Batch size for data loading. num_workers (int): Number of workers for data loading. platform (str): Platform identifier used in metadata. 
""" def __init__( # noqa: PLR0913 self, train_chip_dir, train_label_dir, val_chip_dir, val_label_dir, metadata_path, batch_size, num_workers, platform, ): super().__init__() self.train_chip_dir = train_chip_dir self.train_label_dir = train_label_dir self.val_chip_dir = val_chip_dir self.val_label_dir = val_label_dir self.metadata = Box(yaml.safe_load(open(metadata_path))) self.batch_size = batch_size self.num_workers = num_workers self.platform = platform def setup(self, stage=None): """ Setup datasets for training and validation. Args: stage (str): Stage identifier ('fit' or 'test'). """ if stage in {"fit", None}: self.trn_ds = ChesapeakeDataset( self.train_chip_dir, self.train_label_dir, self.metadata, self.platform, ) self.val_ds = ChesapeakeDataset( self.val_chip_dir, self.val_label_dir, self.metadata, self.platform, ) def train_dataloader(self): """ Create DataLoader for training data. Returns: DataLoader: DataLoader for training dataset. """ return DataLoader( self.trn_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, ) def val_dataloader(self): """ Create DataLoader for validation data. Returns: DataLoader: DataLoader for validation dataset. """ return DataLoader( self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers, ) | |
./finetune/segment/factory.py: """ Clay Segmentor for semantic segmentation tasks. Attribution: Decoder from Segformer: Simple and Efficient Design for Semantic Segmentation with Transformers Paper URL: https://arxiv.org/abs/2105.15203 """ import re import torch from einops import rearrange, repeat from torch import nn from src.model import Encoder class SegmentEncoder(Encoder): """ Encoder class for segmentation tasks, incorporating a feature pyramid network (FPN). Attributes: feature_maps (list): Indices of layers to be used for generating feature maps. ckpt_path (str): Path to the clay checkpoint file. """ def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, shuffle, dim, depth, heads, dim_head, mlp_ratio, feature_maps, ckpt_path=None, ): super().__init__( mask_ratio, patch_size, shuffle, dim, depth, heads, dim_head, mlp_ratio, ) self.feature_maps = feature_maps # Define Feature Pyramid Network (FPN) layers self.fpn1 = nn.Sequential( nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2), nn.BatchNorm2d(dim), nn.GELU(), nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn5 = nn.Identity() # Set device self.device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) # Load model from checkpoint if provided self.load_from_ckpt(ckpt_path) def load_from_ckpt(self, ckpt_path): """ Load the model's state from a checkpoint file. Args: ckpt_path (str): The path to the checkpoint file. """ if ckpt_path: # Load checkpoint ckpt = torch.load(ckpt_path, map_location=self.device) state_dict = ckpt.get("state_dict") # Prepare new state dict with the desired subset and naming new_state_dict = { re.sub(r"^model\.encoder\.", "", name): param for name, param in state_dict.items() if name.startswith("model.encoder") } # Load the modified state dict into the model model_state_dict = self.state_dict() for name, param in new_state_dict.items(): if ( name in model_state_dict and param.size() == model_state_dict[name].size() ): model_state_dict[name].copy_(param) else: print(f"No matching parameter for {name} with size {param.size()}") # Freeze the loaded parameters for name, param in self.named_parameters(): if name in new_state_dict: param.requires_grad = False def forward(self, datacube): """ Forward pass of the SegmentEncoder. Args: datacube (dict): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelenths. Returns: list: A list of feature maps extracted from the datacube. 
""" cube, time, latlon, gsd, waves = ( datacube["pixels"], # [B C H W] datacube["time"], # [B 2] datacube["latlon"], # [B 2] datacube["gsd"], # 1 datacube["waves"], # [N] ) B, C, H, W = cube.shape # Patchify and create embeddings per patch patches, waves_encoded = self.to_patch_embed(cube, waves) # [B L D] patches = self.add_encodings(patches, time, latlon, gsd) # [B L D] # Add class tokens cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] features = [] for idx, (attn, ff) in enumerate(self.transformer.layers): patches = attn(patches) + patches patches = ff(patches) + patches if idx in self.feature_maps: _cube = rearrange( patches[:, 1:, :], "B (H W) D -> B D H W", H=H // 8, W=W // 8 ) features.append(_cube) patches = self.transformer.norm(patches) _cube = rearrange(patches[:, 1:, :], "B (H W) D -> B D H W", H=H // 8, W=W // 8) features.append(_cube) # Apply FPN layers ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4, self.fpn5] for i in range(len(features)): features[i] = ops[i](features[i]) return features class Segmentor(nn.Module): """ Clay Segmentor class that combines the Encoder with FPN layers for semantic segmentation. Attributes: num_classes (int): Number of output classes for segmentation. feature_maps (list): Indices of layers to be used for generating feature maps. ckpt_path (str): Path to the checkpoint file. """ def __init__(self, num_classes, feature_maps, ckpt_path): super().__init__() # Default values are for the clay mae base model. self.encoder = SegmentEncoder( mask_ratio=0.0, patch_size=8, shuffle=False, dim=768, depth=12, heads=12, dim_head=64, mlp_ratio=4.0, feature_maps=feature_maps, ckpt_path=ckpt_path, ) self.upsamples = [nn.Upsample(scale_factor=2**i) for i in range(4)] + [ nn.Upsample(scale_factor=4), ] self.fusion = nn.Conv2d(self.encoder.dim CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py 5, self.encoder.dim, kernel_size=1) self.seg_head = nn.Conv2d(self.encoder.dim, num_classes, kernel_size=1) def forward(self, datacube): """ Forward pass of the Segmentor. Args: datacube (dict): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelenths. Returns: torch.Tensor: The segmentation logits. """ features = self.encoder(datacube) for i in range(len(features)): features[i] = self.upsamples[i](features[i]) fused = torch.cat(features, dim=1) fused = self.fusion(fused) logits = self.seg_head(fused) return logits | |
./finetune/segment/segment.py: """ Command line interface to run the neural network model! From the project root directory, do: python segment.py fit --config configs/segment_chesapeake.yaml References: - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 """ from lightning.pytorch.cli import LightningCLI from finetune.segment.chesapeake_datamodule import ChesapeakeDataModule # noqa: F401 from finetune.segment.chesapeake_model import ChesapeakeSegmentor # noqa: F401 # %% def cli_main(): """ Command-line interface to run Segmentation Model with ChesapeakeDataModule. """ cli = LightningCLI(ChesapeakeSegmentor, ChesapeakeDataModule) return cli # %% if __name__ == "__main__": cli_main() print("Done!")
./finetune/regression/preprocess_data.py: import random from multiprocessing import Pool from pathlib import Path from typing import List import click import numpy as np from tifffile import imread EXPECTED_NR_OF_FILES_PER_TILE = 24 MONTHS = [ "00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", ] def list_uniqe_ids(src: Path) -> List[str]: ids = list(set(dat.name.split("_")[0] for dat in src.glob("*.tif"))) print(f"Found {len(ids)} unique tile ids") return ids def process_data_for_id( id: str, feature_path: Path, cubes_path: Path, overwrite: bool ) -> None: if not overwrite and (cubes_path / f"biomasters_cube_{id}.npz").exists(): print(f"Found existing file for {id}, skipping.") return data = [] for month in MONTHS: data_month = [] for platform in ["S1", "S2"]: feature_name = f"{id}_{platform}_{month}.tif" if not Path(feature_path / feature_name).exists(): continue file_data = ( imread(feature_path / feature_name).swapaxes(1, 2).swapaxes(0, 1) ) ND1 = 0 ND2 = -9999 if platform == "S1": # Limit to first orbit (the other is mostly nodata) file_data = file_data[:2] file_data = np.ma.array( file_data, mask=np.logical_or(file_data == ND1, file_data == ND2) ) else: file_data = file_data[:10] file_data = np.ma.array(file_data, mask=file_data == ND1) data_month.append(file_data) data_month = np.ma.vstack(data_month) NR_OF_BANDS_EXPECTED = 12 if data_month.shape[0] != NR_OF_BANDS_EXPECTED: continue data.append(data_month) cube = np.ma.array(data) mean_cube = np.ma.mean(cube, axis=0) if np.sum(mean_cube.mask): print("Nodata", np.sum(mean_cube.mask)) NODATA_THRESHOLD = 1e5 if np.sum(mean_cube.mask) > NODATA_THRESHOLD: print("Skipping due to lots of nodata") return np.savez_compressed(cubes_path / f"biomasters_cube_{id}.npz", cube=mean_cube) @click.command() @click.option( "--features", help="Folder with features (training or test)", type=click.Path(path_type=Path), ) @click.option( "--cubes", help="Folder to write the datacubes", type=click.Path(path_type=Path) ) @click.option( "--processes", default=1, help="How many processes to use for parallel processing", type=click.INT, ) @click.option( "--sample", default=0.05, help="Fraction of original data to sample", type=click.FloatRange(0, 1), ) @click.option( "--overwrite", is_flag=True, help="Overwrite existing cubes", ) def process(features, cubes, processes, sample, overwrite): """ Combine tiff files into npz datacubes. The datacubes will contain the S1 vv/vh bands for asc and desc orbits, stacked with the first 10 S2 bands. 
""" ids = list_uniqe_ids(features) if sample < 1: sample_length = int(len(ids) CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py sample) random.seed(42) ids = random.sample(ids, sample_length) print(f"Subsampled {len(ids)} tiles") if processes > 1: features = [features] CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py len(ids) cubes = [cubes] CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py len(ids) overwrite = [overwrite] CODE_OF_CONDUCT.md LICENSE LICENSE-MODEL.md README.md conda-lock.yml configs docs environment.yml finetune model-all.py.txt ruff.toml src train_clay.sh trainer.py len(ids) with Pool(processes) as pl: pl.starmap(process_data_for_id, zip(ids, features, cubes, overwrite)) else: for id in ids: process_data_for_id(id, features, cubes, overwrite) if __name__ == "__main__": process() | |
./finetune/regression/regression.py: """ Command line interface to run the neural network model! From the project root directory, do: python regression.py fit --config configs/regression_biomasters.yaml References: - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 """ from lightning.pytorch.cli import LightningCLI from finetune.regression.biomasters_datamodule import BioMastersDataModule # noqa: F401 from finetune.regression.biomasters_model import BioMastersClassifier # noqa: F401 # %% def cli_main(): """ Command-line interface to run Regression with BioMastersDataModule. """ cli = LightningCLI( BioMastersClassifier, BioMastersDataModule, save_config_kwargs={"overwrite": True}, ) return cli # %% if __name__ == "__main__": cli_main() print("Done!")
./finetune/regression/factory.py: """ Clay Segmentor for semantic segmentation tasks. Attribution: Decoder from Segformer: Simple and Efficient Design for Semantic Segmentation with Transformers Paper URL: https://arxiv.org/abs/2105.15203 """ import re import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import nn from src.model import Encoder class SegmentEncoder(Encoder): """ Encoder class for segmentation tasks, incorporating a feature pyramid network (FPN). Attributes: feature_maps (list): Indices of layers to be used for generating feature maps. ckpt_path (str): Path to the clay checkpoint file. """ def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, shuffle, dim, depth, heads, dim_head, mlp_ratio, feature_maps, ckpt_path=None, ): super().__init__( mask_ratio, patch_size, shuffle, dim, depth, heads, dim_head, mlp_ratio, ) self.feature_maps = feature_maps # Define Feature Pyramid Network (FPN) layers self.fpn1 = nn.Sequential( nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2), nn.BatchNorm2d(dim), nn.GELU(), nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn5 = nn.Sequential( nn.MaxPool2d(kernel_size=4, stride=4), ) # Set device self.device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) # Load model from checkpoint if provided self.load_from_ckpt(ckpt_path) def load_from_ckpt(self, ckpt_path): """ Load the model's state from a checkpoint file. Args: ckpt_path (str): The path to the checkpoint file. """ if ckpt_path: # Load checkpoint ckpt = torch.load(ckpt_path, map_location=self.device) state_dict = ckpt.get("state_dict") # Prepare new state dict with the desired subset and naming new_state_dict = { re.sub(r"^model\.encoder\.", "", name): param for name, param in state_dict.items() if name.startswith("model.encoder") } # Load the modified state dict into the model model_state_dict = self.state_dict() for name, param in new_state_dict.items(): if ( name in model_state_dict and param.size() == model_state_dict[name].size() ): model_state_dict[name].copy_(param) else: print(f"No matching parameter for {name} with size {param.size()}") # Freeze the loaded parameters for name, param in self.named_parameters(): if name in new_state_dict: param.requires_grad = False def forward(self, datacube): """ Forward pass of the SegmentEncoder. Args: datacube (dict): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelenths. Returns: list: A list of feature maps extracted from the datacube. 
""" cube, time, latlon, gsd, waves = ( datacube["pixels"], # [B C H W] datacube["time"], # [B 2] datacube["latlon"], # [B 2] datacube["gsd"], # 1 datacube["waves"], # [N] ) B, C, H, W = cube.shape # Patchify and create embeddings per patch patches, waves_encoded = self.to_patch_embed(cube, waves) # [B L D] patches = self.add_encodings(patches, time, latlon, gsd) # [B L D] # Add class tokens cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] features = [] for idx, (attn, ff) in enumerate(self.transformer.layers): patches = attn(patches) + patches patches = ff(patches) + patches if idx in self.feature_maps: _cube = rearrange( patches[:, 1:, :], "B (H W) D -> B D H W", H=H // 8, W=W // 8 ) features.append(_cube) # patches = self.transformer.norm(patches) # _cube = rearrange(patches[:, 1:, :], "B (H W) D -> B D H W", H=H//8, W=W//8) # features.append(_cube) # Apply FPN layers ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4, self.fpn5] for i in range(len(features)): features[i] = ops[i](features[i]) return features class FusionBlock(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1) self.bn = nn.BatchNorm2d(output_dim) def forward(self, x): x = F.relu(self.bn(self.conv(x))) return x class SegmentationHead(nn.Module): def __init__(self, input_dim, num_classes): super().__init__() self.conv1 = nn.Conv2d(input_dim, input_dim // 2, kernel_size=3, padding=1) self.conv2 = nn.Conv2d( input_dim // 2, num_classes, kernel_size=1 ) # final conv to num_classes self.bn1 = nn.BatchNorm2d(input_dim // 2) def forward(self, x): x = F.relu(self.bn1(self.conv1(x))) x = self.conv2(x) # No activation before final layer return x class Regressor(nn.Module): """ Clay Regressor class that combines the Encoder with FPN layers for semantic regression. Attributes: num_classes (int): Number of output classes for segmentation. feature_maps (list): Indices of layers to be used for generating feature maps. ckpt_path (str): Path to the checkpoint file. """ def __init__(self, num_classes, feature_maps, ckpt_path): super().__init__() # Default values are for the clay mae base model. self.encoder = SegmentEncoder( mask_ratio=0.0, patch_size=8, shuffle=False, dim=768, depth=12, heads=12, dim_head=64, mlp_ratio=4.0, feature_maps=feature_maps, ckpt_path=ckpt_path, ) self.upsamples = [nn.Upsample(scale_factor=2**i) for i in range(5)] self.fusion = FusionBlock(self.encoder.dim, self.encoder.dim // 4) self.seg_head = nn.Conv2d( self.encoder.dim // 4, num_classes, kernel_size=3, padding=1 ) def forward(self, datacube): """ Forward pass of the Regressor. Args: datacube (dict): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelenths. Returns: torch.Tensor: The segmentation logits. """ features = self.encoder(datacube) for i in range(len(features)): features[i] = self.upsamples[i](features[i]) # fused = torch.cat(features, dim=1) fused = torch.sum(torch.stack(features), dim=0) fused = self.fusion(fused) logits = self.seg_head(fused) return logits | |
./finetune/regression/biomasters_model.py: import lightning as L import torch import torch.nn.functional as F from torch import nn, optim from torchmetrics import MeanSquaredError from finetune.regression.factory import Regressor class NoNaNRMSE(nn.Module): def __init__(self, threshold=400): super().__init__() self.threshold = threshold def forward(self, logits, target): not_nan = target < self.threshold # logits = logits.squeeze(1) diff = logits - target diff[~not_nan] = 0 diff2 = torch.square(diff) diff2m = (diff2 / not_nan.sum((-1, -2, -3), keepdim=True)).sum((-1, -2, -3)) diff2msqrt = torch.sqrt(diff2m) rmse = diff2msqrt.mean() return rmse class BioMastersClassifier(L.LightningModule): """ LightningModule for training and evaluating a regression on the BioMasters dataset. Args: num_classes (int): Number of classes for classification. ckpt_path (str): Clay MAE pretrained checkpoint path. lr (float): Learning rate for the optimizer. wd (float): Weight decay for the optimizer. b1 (float): Beta1 parameter for the Adam optimizer. b2 (float): Beta2 parameter for the Adam optimizer. """ def __init__(self, ckpt_path, feature_maps, lr, wd, b1, b2): # noqa: PLR0913 super().__init__() self.save_hyperparameters() # self.model = Classifier(num_classes=1, ckpt_path=ckpt_path) self.model = Regressor( num_classes=1, feature_maps=feature_maps, ckpt_path=ckpt_path ) self.loss_fn = NoNaNRMSE() self.score_fn = MeanSquaredError() def forward(self, datacube): """ Forward pass through the classifier. Args: datacube (dict): A dictionary containing the input datacube and meta information like time, latlon, gsd & wavelenths. Returns: torch.Tensor: The output logits from the classifier. """ # Wavelengths for S1 and S2 bands of BioMasters dataset waves = torch.tensor( [ # 3.5, # S1 # 4.0, 0.493, # S2 0.56, 0.665, 0.704, 0.74, 0.783, 0.842, 0.865, 1.61, 2.19, ] ) gsd = torch.tensor(10.0) return self.model( { "pixels": datacube["pixels"], "time": datacube["time"], "latlon": datacube["latlon"], "gsd": gsd, "waves": waves, } ) def configure_optimizers(self): """ Configure the optimizer and learning rate scheduler. Returns: dict: A dictionary containing the optimizer and learning rate scheduler. """ optimizer = optim.AdamW( [ param for name, param in self.model.named_parameters() if param.requires_grad ], lr=self.hparams.lr, weight_decay=self.hparams.wd, betas=(self.hparams.b1, self.hparams.b2), ) scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "interval": "epoch", }, } def shared_step(self, batch, batch_idx, phase): """ Perform a shared step for both training and validation. Args: batch (dict): A batch of data. batch_idx (int): The index of the batch. phase (str): The phase ('train' or 'val'). Returns: torch.Tensor: The computed loss for the batch. """ labels = batch["label"] logits = self(batch) logits = F.interpolate( logits, size=(256, 256), mode="bilinear", align_corners=False, ) # Resize to match labels size # print("Logits shape", logits.shape) # print("Labels shape", labels.shape) loss = self.loss_fn(logits, labels) score = self.score_fn(logits, labels) # Convert to RMSE score = torch.sqrt(score) self.log( f"{phase}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( f"{phase}/score", score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) return loss def training_step(self, batch, batch_idx): """ Perform a training step. 
Args: batch (dict): A batch of training data. batch_idx (int): The index of the batch. Returns: torch.Tensor: The computed loss for the batch. """ return self.shared_step(batch, batch_idx, "train") def validation_step(self, batch, batch_idx): """ Perform a validation step. Args: batch (dict): A batch of validation data. batch_idx (int): The index of the batch. Returns: torch.Tensor: The computed loss for the batch. """ return self.shared_step(batch, batch_idx, "val") | |
./finetune/regression/biomasters_datamodule.py: """ DataModule for the BioMasters dataset for a regression task. BioMassters: A Benchmark Dataset for Forest Biomass Estimation using Multi-modal Satellite Time-series https://nascetti-a.github.io/BioMasster/ This implementation provides a structured way to handle the data loading and preprocessing required for training and validating a regression model. Citation: Andrea Nascetti, Ritu Yadav, Kirill Brodt, Qixun Qu, Hongwei Fan, Yuri Shendryk, Isha Shah, and Christine Chung, BioMassters: A Benchmark Dataset for Forest Biomass Estimation using Multi-modal Satellite Time-series, Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track, 2023, https://openreview.net/forum?id=hrWsIC4Cmz """ from pathlib import Path import lightning as L import numpy as np import torch import yaml from box import Box from tifffile import imread from torch.utils.data import DataLoader, Dataset from torchvision.transforms import v2 class BioMastersDataset(Dataset): """ Dataset class for the BioMasters regression dataset. Assumes band order vv, vh, vv, vh, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12 Args: chip_dir (str): Directory containing the image chips. label_dir (str): Directory containing the labels. """ def __init__(self, chip_dir, label_dir, metadata): self.chip_dir = Path(chip_dir) self.label_dir = Path(label_dir) self.metadata = metadata # Load statistics from Clay metadata s2_mean = list(metadata["sentinel-2-l2a"].bands.mean.values()) s2_std = list(metadata["sentinel-2-l2a"].bands.std.values()) # Duplicate the S1 statistics so that the asc/desc orbit data # is handled correctly self.transform = self.create_transforms( mean=s2_mean, std=s2_std, ) # Load chip and label file names self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npz")] print(f"Found {len(self.chips)} chips to process for {chip_dir}") def create_transforms(self, mean, std): """ Create normalization transforms. Args: mean (list): Mean values for normalization. std (list): Standard deviation values for normalization. Returns: torchvision.transforms.Compose: A composition of transforms. """ return v2.Compose( [ v2.Normalize(mean=mean, std=std), ], ) def __len__(self): return len(self.chips) def __getitem__(self, idx): """ Get a sample from the dataset. Args: idx (int): Index of the sample. Returns: dict: A dictionary containing the image, label, and additional information. """ chip_name = self.chip_dir / self.chips[idx] label_name = self.label_dir / (chip_name.stem.split("_")[-1] + "_agbm.tif") chip = np.load(chip_name)["cube"][2:, ...].astype("float32") label = imread(label_name).astype("float32") label = np.expand_dims(label, 0) sample = { "pixels": self.transform(torch.from_numpy(chip)), "label": torch.from_numpy(label), "time": torch.zeros(4), # Placeholder for time information "latlon": torch.zeros(4), # Placeholder for latlon information } return sample class BioMastersDataModule(L.LightningDataModule): """ DataModule class for the Chesapeake Bay dataset. Args: train_chip_dir (str): Directory containing training image chips. train_label_dir (str): Directory containing training labels. val_chip_dir (str): Directory containing validation image chips. val_label_dir (str): Directory containing validation labels. metadata_path (str): Path to the metadata file. batch_size (int): Batch size for data loading. num_workers (int): Number of workers for data loading. platform (str): Platform identifier used in metadata. 
""" def __init__( # noqa: PLR0913 self, train_chip_dir, train_label_dir, val_chip_dir, val_label_dir, metadata_path, batch_size, num_workers, ): super().__init__() self.train_chip_dir = train_chip_dir self.train_label_dir = train_label_dir self.val_chip_dir = val_chip_dir self.val_label_dir = val_label_dir self.metadata = Box(yaml.safe_load(open(metadata_path))) self.batch_size = batch_size self.num_workers = num_workers def setup(self, stage=None): """ Setup datasets for training and validation. Args: stage (str): Stage identifier ('fit' or 'test'). """ if stage in {"fit", None}: self.trn_ds = BioMastersDataset( self.train_chip_dir, self.train_label_dir, self.metadata, ) self.val_ds = BioMastersDataset( self.val_chip_dir, self.val_label_dir, self.metadata, ) def train_dataloader(self): """ Create DataLoader for training data. Returns: DataLoader: DataLoader for training dataset. """ return DataLoader( self.trn_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, ) def val_dataloader(self): """ Create DataLoader for validation data. Returns: DataLoader: DataLoader for validation dataset. """ return DataLoader( self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers, ) | |
./trainer.py: """ Command line interface to run the neural network model! From the project root directory, do: python trainer.py fit References: - https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html - https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6 """ from lightning.pytorch.cli import LightningCLI from src.datamodule import ClayDataModule # noqa: F401 from src.model import ClayMAEModule # noqa: F401 # %% def cli_main(): """ Command-line interface to run ClayMAE with ClayDataModule. """ cli = LightningCLI(save_config_kwargs={"overwrite": True}) return cli # %% if __name__ == "__main__": cli_main() print("Done!")
./src/callbacks_wandb.py: """ Lightning callback functions for logging to Weights & Biases. Includes a way to visualize RGB images derived from the raw logits of a Masked Autoencoder's decoder during the validation loop. I.e. to see if the Vision Transformer model is learning how to do image reconstruction. Usage: ``` import lightning as L from src.callbacks_wandb import LogMAEReconstruction trainer = L.Trainer( ..., callbacks=[LogMAEReconstruction(num_samples=6)] ) ``` References: - https://lightning.ai/docs/pytorch/2.1.0/common/trainer.html#callbacks - https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY - https://github.com/ashleve/lightning-hydra-template/blob/wandb-callbacks/src/callbacks/wandb_callbacks.py#L245 """ import lightning as L import matplotlib.pyplot as plt import numpy as np import skimage import torch from einops import rearrange try: import wandb except ImportError: wandb = None # %% def get_wandb_logger(trainer: L.Trainer) -> L.pytorch.loggers.WandbLogger: """ Safely get Weights & Biases logger from Trainer. """ if trainer.fast_dev_run: raise Exception( "Cannot use wandb callbacks since pytorch lightning disables " "loggers in `fast_dev_run=true` mode." ) for logger in trainer.loggers: if isinstance(logger, L.pytorch.loggers.WandbLogger): return logger break raise Exception( "You are using wandb related callback, " "but WandbLogger was not found for some reason..." ) class LogMAEReconstruction(L.Callback): """ Logs reconstructed RGB images from a Masked Autoencoder's decoder to WandB. """ def __init__(self, num_samples: int = 8): """ Define how many sample images to log. Parameters ---------- num_samples : int The number of RGB image samples to upload to WandB. Default is 8. """ super().__init__() self.num_samples: int = num_samples self.ready: bool = False if wandb is None: raise ModuleNotFoundError( "Package `wandb` is required to be installed to use this callback. " "Please use `pip install wandb` or " "`conda install -c conda-forge wandb` " "to install the package" ) def on_sanity_check_start(self, trainer, pl_module): """ Don't execute callback before validation sanity checks are completed. """ self.ready = False def on_sanity_check_end(self, trainer, pl_module): """ Start executing callback only after all validation sanity checks end. """ self.ready = True def on_validation_batch_end( self, trainer: L.Trainer, pl_module: L.LightningModule, outputs: dict[str, torch.Tensor], batch: dict[str, torch.Tensor | list[str]], batch_idx: int, ) -> list: """ Called in the validation loop at the start of every mini-batch. Gather a sample of data from the first mini-batch, get the RGB bands, apply histogram equalization to the image, and log it to WandB. 
""" if self.ready and batch_idx == 0: # only run on first mini-batch with torch.inference_mode(): # Get WandB logger self.logger = get_wandb_logger(trainer=trainer) # Turn raw logits into reconstructed 512x512 images patchified_pixel_values: torch.Tensor = outputs["logits"] # assert patchified_pixel_values.shape == torch.Size([32, 64, 53248]) y_hat: torch.Tensor = pl_module.vit.unpatchify( patchified_pixel_values=patchified_pixel_values ) # assert y_hat.shape == torch.Size([32, 13, 512, 512]) # Reshape tensors from channel-first to channel-last x: torch.Tensor = torch.einsum( "bchw->bhwc", batch["image"][: self.num_samples] ) y_hat: torch.Tensor = torch.einsum( "bchw->bhwc", y_hat[: self.num_samples] ) # assert y_hat.shape == torch.Size([8, 512, 512, 13]) assert x.shape == y_hat.shape # Plot original and reconstructed RGB images of Sentinel-2 rgb_original: np.ndarray = ( x[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy() ) rgb_reconstruction: np.ndarray = ( y_hat[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy() ) figures: list[wandb.Image] = [] for i in range(min(x.shape[0], self.num_samples)): img_original = wandb.Image( data_or_path=skimage.exposure.equalize_hist( image=rgb_original[i] ), caption=f"RGB Image {i}", ) figures.append(img_original) img_reconstruction = wandb.Image( data_or_path=skimage.exposure.equalize_hist( image=rgb_reconstruction[i] ), caption=f"Reconstructed {i}", ) figures.append(img_reconstruction) # Upload figures to WandB self.logger.experiment.log(data={"Examples": figures}) return figures class LogIntermediatePredictions(L.Callback): """Visualize the model results at the end of every epoch.""" def __init__(self): """ Instantiates with wandb-logger. """ super().__init__() def on_validation_end( self, trainer: L.Trainer, pl_module: L.LightningModule, ) -> None: """ Called when the validation loop ends. At the end of each epoch, takes the first batch from validation dataset & logs the model predictions to wandb-logger for humans to interpret how model evolves over time. 
""" with torch.no_grad(): # Get WandB logger self.logger = get_wandb_logger(trainer=trainer) # get the val dataloader val_dl = iter(trainer.val_dataloaders) for i in range(6): batch = next(val_dl) platform = batch["platform"][0] batch = { k: v.to(pl_module.device) for k, v in batch.items() if isinstance(v, torch.Tensor) } waves = torch.tensor( list( trainer.datamodule.metadata[platform].bands.wavelength.values() ) ) gsd = torch.tensor(trainer.datamodule.metadata[platform].gsd) # ENCODER ( encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, ) = pl_module.model.encoder( { "pixels": batch["pixels"], "time": batch["time"], "latlon": batch["latlon"], "gsd": gsd, "waves": waves, } ) # DECODER pixels, waves = pl_module.model.decoder( encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, batch["time"], batch["latlon"], gsd, waves, ) # pixels: batch x (patch x patch) x 1024 pixels = rearrange( pixels, "b (h w) (c p1 p2) -> b c (h p1) (w p2)", p1=pl_module.model.patch_size, p2=pl_module.model.patch_size, h=trainer.datamodule.size // pl_module.model.patch_size, w=trainer.datamodule.size // pl_module.model.patch_size, ) assert pixels.shape == batch["pixels"].shape n_rows = 4 # 2 for actual and 2 for predicted n_cols = 8 fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 8)) for j in range(n_cols): # Plot actual images in rows 0 and 2 axs[0, j].imshow( batch["pixels"][j][0].detach().cpu().numpy(), cmap="viridis" ) axs[0, j].set_title(f"Actual {j}") axs[0, j].axis("off") axs[2, j].imshow( batch["pixels"][j + n_cols][0].detach().cpu().numpy(), cmap="viridis", ) axs[2, j].set_title(f"Actual {j+n_cols}") axs[2, j].axis("off") # Plot predicted images in rows 1 and 3 axs[1, j].imshow( pixels[j][0].detach().cpu().numpy(), cmap="viridis" ) axs[1, j].set_title(f"Pred {j}") axs[1, j].axis("off") axs[3, j].imshow( pixels[j + n_cols][0].detach().cpu().numpy(), cmap="viridis" ) axs[3, j].set_title(f"Pred {j+n_cols}") axs[3, j].axis("off") self.logger.experiment.log({f"{platform}": wandb.Image(fig)}) plt.close(fig) | |
./src/factory.py: """Dynamic Embedding from DOFA paper. Reference: - https://arxiv.org/abs/2403.15356 - https://github.com/zhu-xlab/DOFA """ import torch import torch.nn.functional as F from einops import rearrange from torch import nn from src.utils import posemb_sincos_1d class FCBlock(nn.Module): def __init__(self, size): super().__init__() self.l1 = nn.Linear(size, size) self.l2 = nn.Linear(size, size) def forward(self, x): y = F.gelu(self.l1(x)) y = F.gelu(self.l2(y)) return x + y class WavesTransformer(nn.Module): def __init__( # noqa: PLR0913 self, wave_dim, output_dim, num_latent_tokens, embed_dim, is_decoder, num_heads=4, num_layers=1, ): super().__init__() self.num_latent_tokens = num_latent_tokens self.is_decoder = is_decoder layer = nn.TransformerEncoderLayer( d_model=wave_dim, nhead=num_heads, activation="gelu", dropout=0, norm_first=False, batch_first=False, ) self.encoder = nn.TransformerEncoder(layer, num_layers) self.fc_weight = nn.Linear(wave_dim, output_dim) self.fc_bias = None if self.is_decoder else nn.Linear(wave_dim, embed_dim) self.weight_tokens = nn.Parameter( torch.randn(self.num_latent_tokens, wave_dim) * 0.02 ) self.bias_token = nn.Parameter(torch.randn(1, wave_dim) * 0.02) def forward(self, x): x = torch.cat([self.weight_tokens, x, self.bias_token], dim=0) out = self.encoder(x) weights = self.fc_weight( out[self.num_latent_tokens : -1] + x[self.num_latent_tokens : -1] ) bias = None if self.is_decoder else self.fc_bias(out[-1]) return weights, bias class DynamicEmbedding(nn.Module): def __init__( self, wave_dim, num_latent_tokens, patch_size, embed_dim, is_decoder=False, ): super().__init__() self.wave_dim = wave_dim self.num_latent_tokens = num_latent_tokens self.patch_size = patch_size self.embed_dim = embed_dim self.is_decoder = is_decoder self.output_dim = (patch_size**2) * embed_dim self.weight_generator = WavesTransformer( wave_dim, self.output_dim, self.num_latent_tokens, self.embed_dim, is_decoder, ) self.fclayer = FCBlock(self.wave_dim) self.initialize_weights() def forward(self, batch, waves): waves = posemb_sincos_1d(waves, self.wave_dim) waves = waves.to(batch.device) waves = self.fclayer(waves) weight, bias = self.weight_generator(waves) if self.is_decoder: dynamic_weight = rearrange( weight, "cin (k1 k2 cout) -> (cin k1 k2) cout", k1=self.patch_size, k2=self.patch_size, cout=self.embed_dim, ) if bias is not None: bias = rearrange(bias, "b -> (b)") dynamic_out = F.linear(batch, dynamic_weight * 0.02, bias=bias) x = dynamic_out else: dynamic_weight = rearrange( weight, "cin (cout k1 k2) -> cout cin k1 k2", k1=self.patch_size, k2=self.patch_size, ) if bias is not None: bias = rearrange(bias, "b -> (b)") dynamic_out = F.conv2d( batch, dynamic_weight * 0.02,
bias=bias, stride=self.patch_size ) x = rearrange(dynamic_out, "b c h w -> b (h w) c") return x, waves def initialize_weights(self): # Initialize weights using Xavier initialization for m in self.modules(): if isinstance(m, (nn.Linear, nn.Conv2d)): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) | |
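A minimal usage sketch (not from the repository) of calling DynamicEmbedding on its own, assuming the encoder-style settings used in ./src/model.py (wave_dim=128, num_latent_tokens=128); the chip shape and band wavelengths are made-up examples:

import torch
from src.factory import DynamicEmbedding

# Encoder-style embedding: conv patchify with weights generated from the band wavelengths
embedder = DynamicEmbedding(wave_dim=128, num_latent_tokens=128, patch_size=8, embed_dim=768)
pixels = torch.randn(2, 4, 64, 64)                  # [B C H W] chip with 4 bands (hypothetical)
waves = torch.tensor([0.665, 0.560, 0.490, 0.842])  # hypothetical wavelengths, one per band
patches, wave_encoding = embedder(pixels, waves)    # patches: [2, 64, 768], wave_encoding: [4, 128]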
./src/model.py: import math import os from typing import Literal import lightning as L import timm import torch import torch.nn.functional as F import yaml from box import Box from einops import rearrange, reduce, repeat from torch import nn from torchvision.transforms import v2 from vit_pytorch.simple_vit import Transformer from src.factory import DynamicEmbedding from src.utils import posemb_sincos_2d_with_gsd torch.set_float32_matmul_precision("medium") os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1" class Encoder(nn.Module): def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, shuffle, dim, depth, heads, dim_head, mlp_ratio, ): super().__init__() self.mask_ratio = mask_ratio self.patch_size = patch_size self.shuffle = shuffle self.dim = dim self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02) self.patch_embedding = DynamicEmbedding( wave_dim=128, num_latent_tokens=128, patch_size=patch_size, embed_dim=dim, is_decoder=False, ) self.transformer = Transformer( dim=dim, depth=depth, heads=heads, dim_head=dim_head, mlp_dim=int(dim * mlp_ratio), ) def to_patch_embed(self, cube, waves): """Split the input cube into patches & create embeddings per patch""" patches, waves_encoded = self.patch_embedding(cube, waves) # [B L D] return patches, waves_encoded # ([B L D], [N D]) def add_encodings(self, patches, time, latlon, gsd): """Add position encoding to the patches""" B, L, D = patches.shape grid_size = int(math.sqrt(L)) self.num_patches = grid_size**2 pos_encoding = ( posemb_sincos_2d_with_gsd( h=grid_size, w=grid_size, dim=(self.dim - 8), gsd=gsd, ) .to(patches.device) .detach() ) # [L (D - 8)] time_latlon = torch.hstack((time, latlon)).to(patches.device).detach() # [B 8] pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] pos_metadata_encoding = torch.cat( (pos_encoding, time_latlon), dim=-1 ) # [B L D] patches = patches + pos_metadata_encoding # [B L D] + [B L D] -> [B L D] return patches # [B L D] def mask_out(self, patches): """ Mask out patches randomly by shuffling the patches & masking out the first N patches Parameters ---------- patches : torch.Tensor A tensor of shape (B, L, D) Returns ------- unmasked_patches : torch.Tensor A tensor of shape (B, L:(1 - mask_ratio), D) containing the embeddings of the unmasked patches. unmasked_indices : torch.Tensor A tensor of shape (B, (1 - mask_ratio)) containing the indices of the unmasked patches. masked_indices : torch.Tensor A tensor of shape (B, mask_ratio) containing the indices of the masked patches. masked_matrix : torch.Tensor A tensor of shape (B, L) containing the mask matrix, 1 indicates a masked patch & 0 indicates an unmasked patch. """ B, L, D = patches.shape # assert ( # L == self.num_patches # ), f"Expected {self.num_patches} patches, got {L} patches." if self.shuffle: # Shuffle the patches noise = torch.randn((B, L), device=patches.device) # [B L] else: # Don't shuffle, useful for interpolation & inspection of embeddings noise = rearrange( torch.arange(B * L, device=patches.device), "(B L) -> B L", B=B, L=L ) random_indices = torch.argsort(noise, dim=-1) # [B L] reverse_indices = torch.argsort(random_indices, dim=-1) # [B L] num_masked_patches = int( self.mask_ratio * self.num_patches ) # Number of patches to be masked out masked_indices, unmasked_indices = ( random_indices[:, :num_masked_patches], # [B mask_ratio * L] random_indices[:, num_masked_patches:], # [B (1 - mask_ratio) * L] ) # create a mask of shape B L, where 1 indicates a masked patch # and 0 indicates an unmasked patch masked_matrix = torch.zeros((B, L), device=patches.device) # [B L] = 0 masked_matrix[:, :num_masked_patches] = 1 # [B mask_ratio * L] = 1 masked_matrix = torch.gather( masked_matrix, dim=1, index=reverse_indices ) # [B L] -> [B L] - reorder the patches # mask out the patches batch_indices = rearrange( torch.arange(B, device=patches.device), "B -> B 1" ) # [B 1] unmasked_patches = patches[ batch_indices, unmasked_indices, : ] # [B L:(1 - mask_ratio) D] _ = patches[batch_indices, masked_indices, :] # [B L:mask_ratio D] return ( unmasked_patches, unmasked_indices, masked_indices, masked_matrix, ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] def forward(self, datacube): cube, time, latlon, gsd, waves = ( datacube["pixels"], # [B C H W] datacube["time"], # [B 2] datacube["latlon"], # [B 2] datacube["gsd"], # 1 datacube["waves"], # [N] ) # [B C H W] B, C, H, W = cube.shape patches, waves_encoded = self.to_patch_embed( cube, waves ) # [B L D] - patchify & create embeddings per patch # TODO: Add time & latlon as encoding to patches patches = self.add_encodings( patches, time, latlon, gsd, ) # [B L D] - add position encoding to the embeddings # mask out patches ( unmasked_patches, unmasked_indices, masked_indices, masked_matrix, ) = self.mask_out( patches ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] # Add class tokens cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] unmasked_patches = torch.cat( (cls_tokens, unmasked_patches), dim=1 ) # [B (1 + L) D] # pass the unmasked patches through the transformer encoded_unmasked_patches = self.transformer( unmasked_patches ) # [B ((1 + L)):(1 - mask_ratio)) D] return ( encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, ) # [B ((1 + L):(1 - mask_ratio)) D], [(1-mask_ratio)], [mask_ratio], [B L] class Decoder(nn.Module): def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, encoder_dim, dim, depth, heads, dim_head, mlp_ratio, ): super().__init__() self.mask_ratio = mask_ratio self.patch_size = patch_size self.encoder_dim = encoder_dim self.dim = dim self.enc_to_dec = ( nn.Linear(encoder_dim, dim) if encoder_dim != dim else nn.Identity() ) self.mask_patch = nn.Parameter(torch.randn(dim)) self.transformer = Transformer( dim=dim, depth=depth, heads=heads, dim_head=dim_head, mlp_dim=int(dim * mlp_ratio), ) self.embed_to_pixels = DynamicEmbedding( wave_dim=128, num_latent_tokens=128, patch_size=patch_size, embed_dim=dim, is_decoder=True, ) def reconstruct_and_add_encoding( # noqa: PLR0913 self, unmasked_patches, unmasked_indices, masked_indices, masked_matrix, time, latlon, gsd, ): B, L = masked_matrix.shape grid_size = int(math.sqrt(L)) self.num_patches = grid_size**2 cls_tokens, unmasked_patches = ( unmasked_patches[:, :1, :], unmasked_patches[:, 1:, :], ) # [B 1 D], [B L:(1 - mask_ratio) D] pos_encoding = ( posemb_sincos_2d_with_gsd( h=grid_size, w=grid_size, dim=(self.dim - 8), gsd=gsd ) .to(unmasked_patches.device) .detach() ) # [L D] time_latlon = ( torch.hstack((time, latlon)).to(unmasked_patches.device).detach() ) # [B 8] pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] pos_metadata_encoding = torch.cat( (pos_encoding, time_latlon), dim=-1 ) # [B L D] batch_indices = rearrange( torch.arange(B, device=unmasked_patches.device), "B -> B 1" ) # [B 1] num_masked_patches = int(self.mask_ratio * self.num_patches) masked_patches = repeat( self.mask_patch, "D -> B L D", B=B, L=num_masked_patches ) # [B L:mask_ratio D] # Add position encoding masked_patches = ( masked_patches + pos_metadata_encoding[batch_indices, masked_indices, :] ) # [B L:mask_ratio D] + [B L:mask_ratio D] unmasked_patches = ( unmasked_patches + pos_metadata_encoding[batch_indices, unmasked_indices, :] ) # [B GL:(1 - masked_ratio) D] + [B GL:(1 - mask_ratio) D] # Concatenate the masked & unmasked patches decoder_patches = torch.zeros( (B, self.num_patches, self.dim), device=unmasked_patches.device ) # [B L D] decoder_patches[batch_indices, unmasked_indices, :] = ( unmasked_patches # [B L:(1 - mask_ratio) D]) ) decoder_patches[batch_indices, masked_indices, :] = ( masked_patches # [B L:mask_ratio D]) ) decoder_patches = torch.cat( (cls_tokens, decoder_patches), dim=1 ) # [B (1 + L) D] return decoder_patches # [B (1 + L) D] def forward( # noqa: PLR0913 self, encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, time, latlon, gsd, waves, ): # Change the embedding dimension from encoder to decoder encoded_unmasked_patches = self.enc_to_dec( encoded_unmasked_patches ) # [B (1 + L) D] # Reconstruct the patches to feed into the decoder transformer decoder_patches = self.reconstruct_and_add_encoding( encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, time, latlon, gsd, ) # [B (1 + L) D] # Pass the decoder patches through the transformer decoded_patches = self.transformer(decoder_patches) # [B (1 + L) D] pixels, waves = self.embed_to_pixels( decoded_patches, waves ) # [B (1 + L) (C P P)] # Remove the class token pixels = pixels[:, 1:, :] return pixels, waves # [B L (C P P)], [B N] class ClayMAE(nn.Module): def __init__( # noqa: PLR0913 self, mask_ratio, patch_size, norm_pix_loss, shuffle, metadata, teacher, # ENCODER dim, depth, heads, dim_head, mlp_ratio, # DECODER decoder_dim, decoder_depth, decoder_heads, decoder_dim_head, decoder_mlp_ratio, **kwargs, ): super().__init__() self.mask_ratio = mask_ratio self.patch_size = patch_size self.norm_pix_loss = norm_pix_loss self.shuffle = shuffle self.metadata = metadata self.teacher = timm.create_model(teacher, pretrained=True, num_classes=0) self.teacher_chip_size = 224 self.teacher_resize = v2.Resize( size=(self.teacher_chip_size, self.teacher_chip_size) ) self.proj = nn.Linear(dim, self.teacher.num_features) self.encoder = Encoder( mask_ratio=mask_ratio, patch_size=patch_size, shuffle=shuffle, dim=dim, depth=depth, heads=heads, dim_head=dim_head, mlp_ratio=mlp_ratio, ) self.decoder = Decoder( mask_ratio=mask_ratio, patch_size=patch_size, encoder_dim=dim, dim=decoder_dim, depth=decoder_depth, heads=decoder_heads, dim_head=decoder_dim_head, mlp_ratio=decoder_mlp_ratio, ) self.freeze_teacher() def freeze_teacher(self): for param in self.teacher.parameters(): param.requires_grad = False def per_pixel_loss(self, cube, pixels, masked_matrix): """ cube: [B C H W] pixels: [B L (C P P)] masked_matrix: [B L], 0 is unmasked, 1 is masked """ patches = rearrange( cube, "B C (h p1) (w p2) -> B (h w) (C p1 p2)", p1=self.patch_size, p2=self.patch_size, ) # [B L (C P P)] if self.norm_pix_loss: mean = patches.mean(dim=-1, keepdim=True) var = patches.var(dim=-1, keepdim=True) patches = (patches - mean) / (var + 1e-6) ** 0.5 loss = F.l1_loss(patches, pixels, reduction="none") # loss per pixel loss = reduce(loss, "B L D -> B L", reduction="mean") # loss per patch loss = ( loss * masked_matrix ).sum() / masked_matrix.sum() # loss on masked patches only return loss def forward(self, datacube): """ datacube: dict containing the following keys: - pixels: [B C H W] - time: [B 4] # week hour - latlon: [B 4] # lat lon - platform: [B 1] - date: [B 1] """ platform = datacube["platform"][0] waves = torch.tensor(list(self.metadata[platform].bands.wavelength.values())) gsd = torch.tensor(self.metadata[platform].gsd) # ENCODER ( encoded_unmasked_patches, # [B (1 + L):(1 - mask_ratio) D] unmasked_indices, # [(1-mask_ratio)] masked_indices, # [mask_ratio] masked_matrix, # [B L] ) = self.encoder( { "pixels": datacube["pixels"], "time": datacube["time"], "latlon": datacube["latlon"], "gsd": gsd, "waves": waves, } ) # DECODER pixels, waves = self.decoder( encoded_unmasked_patches, unmasked_indices, masked_indices, masked_matrix, datacube["time"], datacube["latlon"], gsd, waves, ) # [B L (C P P)] # LOSS reconstruction_loss = self.per_pixel_loss( datacube["pixels"], pixels, masked_matrix ) # TEACHER encoder_output = self.proj(encoded_unmasked_patches[:, 0, :]) # [B D'] with torch.no_grad(): if platform == "sentinel-1-rtc": r = datacube["pixels"][:, 0, :, :] g = datacube["pixels"][:, 1, :, :] b = r - g rgb = torch.stack((r, g, b), dim=1) else: # Read RGB bands from the sensor to feed the teacher model indices = self.metadata[platform].rgb_indices rgb = datacube["pixels"][:, indices, :, :] rgb = self.teacher_resize(rgb) teacher_output = self.teacher(rgb) representation_loss = -( F.cosine_similarity(encoder_output, teacher_output).mean() - 1.0 # change range from [-1, 1] to [-2, 0] ) # negative cosine similarity, [0, 2] -> 0 is similar & 2 is opposite loss = 0.90 * reconstruction_loss + 0.10 * representation_loss return (loss, reconstruction_loss, representation_loss) def clay_mae_tiny(**kwargs): args = { # ENCODER "dim": 192, "depth": 6, "heads": 4, "dim_head": 48, "mlp_ratio": 2, # DECODER "decoder_dim": 96, "decoder_depth": 3, "decoder_heads": 2, "decoder_dim_head": 48, "decoder_mlp_ratio": 2, } args.update(kwargs) return ClayMAE(**args) def clay_mae_small(**kwargs): args = { # ENCODER "dim": 384, "depth": 6, "heads": 6, "dim_head": 64, "mlp_ratio": 2, # DECODER "decoder_dim": 192, "decoder_depth": 4, "decoder_heads": 4, "decoder_dim_head": 64, "decoder_mlp_ratio": 2, } args.update(kwargs) return ClayMAE(**args) def clay_mae_base(**kwargs): args = { # ENCODER "dim": 768, "depth": 12, "heads": 12, "dim_head": 64, "mlp_ratio": 4, # DECODER "decoder_dim": 512, "decoder_depth": 6, "decoder_heads": 6, "decoder_dim_head": 64, "decoder_mlp_ratio": 4, } args.update(kwargs) return ClayMAE(**args) def clay_mae_large(**kwargs): args = { # ENCODER "dim": 1024, "depth": 24, "heads": 16, "dim_head": 64, "mlp_ratio": 4, # DECODER "decoder_dim": 512, "decoder_depth": 8, "decoder_heads": 8, "decoder_dim_head": 64, "decoder_mlp_ratio": 4, } args.update(kwargs) return ClayMAE(**args) class ClayMAEModule(L.LightningModule): def __init__( # noqa: PLR0913 self, model_size="base", mask_ratio=0.75, norm_pix_loss=False, patch_size=16, shuffle=False, metadata_path="configs/metadata.yaml", teacher="vit_base_patch16_224.dino", lr=1e-4, wd=0.05, b1=0.9, b2=0.95, embeddings_level: Literal["mean", "patch", "group"] = "mean", ): super().__init__() self.save_hyperparameters(logger=True) self.metadata = Box(yaml.safe_load(open(metadata_path))) model_map = { "tiny": clay_mae_tiny, "small": clay_mae_small, "base": clay_mae_base, "large": clay_mae_large, } if model_size in model_map: model_args = { "mask_ratio": mask_ratio, "patch_size": patch_size, "norm_pix_loss": norm_pix_loss, "shuffle": shuffle, "metadata": self.metadata, "teacher": teacher, } self.model = model_map[model_size](**model_args) else: raise ValueError( f"Invalid model size {model_size}. Expected one of {model_map.keys()}" ) def on_train_epoch_start(self): self.model.teacher.eval() def forward(self, datacube: dict[str, torch.Tensor]): return self.model(datacube) def configure_optimizers(self): optimizer = torch.optim.AdamW( self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.wd, betas=(self.hparams.b1, self.hparams.b2), ) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=1000, T_mult=2, eta_min=self.hparams.lr * 100, last_epoch=-1 ) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "interval": "step", }, } def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, phase: str): datacube = batch loss, reconstruction_loss, representation_loss = self(datacube) self.log( name=f"{phase}/loss", value=loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( name=f"{phase}/rec_loss", value=reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) self.log( name=f"{phase}/rep_loss", value=representation_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, ) return loss def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): return self.shared_step(batch, batch_idx, phase="train") def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int): return self.shared_step(batch, batch_idx, phase="val") | |
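A minimal training sketch (not from the repository) of wiring ClayMAEModule and ClayDataModule together outside the Lightning CLI; the data directory, metadata file, patch size, and Trainer arguments are assumptions, and the DINO teacher weights are downloaded by timm on first use:

import lightning as L

from src.datamodule import ClayDataModule
from src.model import ClayMAEModule

# Smallest model variant, default metadata; assumes .npz chips under data/
model = ClayMAEModule(
    model_size="tiny",
    mask_ratio=0.75,
    patch_size=8,
    metadata_path="configs/metadata.yaml",
    teacher="vit_base_patch16_224.dino",
)
data = ClayDataModule(data_dir="data", size=224, batch_size=4, num_workers=2)
trainer = L.Trainer(max_epochs=1, devices=1)
trainer.fit(model, datamodule=data)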
./src/utils.py: """ Code from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py """ import torch def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature**omega) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) return pe.type(dtype) def posemb_sincos_2d_with_gsd( h, w, dim, gsd=1.0, temperature: int = 10000, dtype=torch.float32 ): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) return pe.type(dtype) def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32): assert ( dim % 2 == 0 ), "Feature dimension must be a multiple of 2 for sincos embedding" pos = torch.arange(pos) if isinstance(pos, int) else pos omega = torch.arange(dim // 2) / (dim // 2 - 1) omega = 1.0 / (temperature**omega) scaled_pos = pos[:, None] * omega[None, :] pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1) return pe.type(dtype) | |
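A short sketch (not from the repository) of the shapes these positional-embedding helpers return; the grid size, dimensions, gsd, and wavelengths below are arbitrary example values:

import torch

from src.utils import posemb_sincos_1d, posemb_sincos_2d, posemb_sincos_2d_with_gsd

pe_plain = posemb_sincos_2d(h=8, w=8, dim=64)                   # [64, 64] -> (h*w, dim)
pe_gsd = posemb_sincos_2d_with_gsd(h=8, w=8, dim=64, gsd=10.0)  # [64, 64], frequencies scaled by gsd
pe_bands = posemb_sincos_1d(pos=torch.tensor([0.49, 0.56, 0.665]), dim=32)  # [3, 32], one row per wavelength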
./src/datamodule.py: """ LightningDataModule to load Earth Observation data from GeoTIFF files using rasterio. """ from collections import defaultdict from pathlib import Path from typing import List, Literal import lightning as L import numpy as np import torch import torchdata import yaml from box import Box from einops import rearrange from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import Sampler from torchvision.transforms import v2 class EODataset(Dataset): """Reads different Earth Observation data sources from a directory.""" def __init__( self, chips_path: List[Path], size: int, platforms: list, metadata: Box ) -> None: super().__init__() self.chips_path = chips_path self.size = size self.transforms = {} # Generate transforms for each platform using a helper function for platform in platforms: mean = list(metadata[platform].bands.mean.values()) std = list(metadata[platform].bands.std.values()) self.transforms[platform] = self.create_transforms(mean, std) def create_transforms(self, mean, std): return v2.Compose( [ v2.RandomHorizontalFlip(p=0.5), v2.RandomVerticalFlip(p=0.5), v2.RandomCrop(size=(self.size, self.size)), v2.Normalize(mean=mean, std=std), ] ) def __len__(self): return len(self.chips_path) def __getitem__(self, idx): chip_path = self.chips_path[idx] with np.load(chip_path, allow_pickle=False) as chip: pixels = torch.from_numpy(chip["pixels"].astype(np.float32)) platform = chip_path.parent.name pixels = self.transforms[platform](pixels) # Prepare additional information additional_info = { "platform": platform, "time": torch.tensor( np.hstack((chip["week_norm"], chip["hour_norm"])), dtype=torch.float32, ), "latlon": torch.tensor( np.hstack((chip["lat_norm"], chip["lon_norm"])), dtype=torch.float32 ), } return {"pixels": pixels, **additional_info} class ClaySampler(Sampler): def __init__(self, dataset, platforms, batch_size): self.dataset = dataset self.platforms = platforms self.batch_size = batch_size self.cubes_per_platform = {platform: [] for platform in platforms} for idx, chip_path in enumerate(self.dataset.chips_path): platform = chip_path.parent.name self.cubes_per_platform[platform].append(idx) def __iter__(self): cubes_per_platform_per_epoch = {} rng = np.random.default_rng() # Shuffle and adjust sizes max_len = max(len(indices) for indices in self.cubes_per_platform.values()) for platform in self.platforms: indices = self.cubes_per_platform[platform] rng.shuffle(indices) repeated_indices = np.tile(indices, (max_len // len(indices) + 1))[:max_len] cubes_per_platform_per_epoch[platform] = repeated_indices # Create batches such that we return one platform per batch in cycle # Ignore the last batch if it is incomplete for i in range(0, max_len, self.batch_size): for platform in self.platforms: batch = cubes_per_platform_per_epoch[platform][i : i + self.batch_size] if len(batch) == self.batch_size: yield batch def __len__(self): return len(self.dataset.chips_path) // self.batch_size def batch_collate(batch): """Collate function for DataLoader. Merge the first two dimensions of the input tensors. 
""" d = defaultdict(list) for item in batch: d["pixels"].append(item["pixels"]) d["time"].append(item["time"]) d["latlon"].append(item["latlon"]) d["platform"].append(item["platform"]) return { "pixels": rearrange(d["pixels"], "b1 b2 c h w -> (b1 b2) c h w"), "time": rearrange(d["time"], "b1 b2 t -> (b1 b2) t"), "latlon": rearrange(d["latlon"], "b1 b2 ll -> (b1 b2) ll"), "platform": d["platform"], } class ClayDataModule(L.LightningDataModule): def __init__( # noqa: PLR0913 self, data_dir: str = "data", size: int = 224, metadata_path: str = "configs/metadata.yaml", platforms: list = [ "landsat-c2l1", "landsat-c2l2-sr", "linz", "naip", "sentinel-1-rtc", "sentinel-2-l2a", ], batch_size: int = 10, num_workers: int = 8, ): super().__init__() self.data_dir = data_dir self.size = size self.platforms = platforms self.metadata = Box(yaml.safe_load(open(metadata_path))) self.batch_size = batch_size self.num_workers = num_workers self.split_ratio = 0.8 def setup(self, stage: Literal["fit", "predict"] | None = None) -> None: # Get list of GeoTIFF filepaths from s3 bucket or data/ folder if self.data_dir.startswith("s3://"): dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.data_dir]) chips_path = list(dp.list_files_by_s3(masks="*.npz")) else: # if self.data_dir is a local data path chips_path = sorted(list(Path(self.data_dir).glob("**/*.npz"))) chips_platform = [chip.parent.parent.name for chip in chips_path] # chips_platform = [chip.parent.parent.name for chip in chips_path] print(f"Total number of chips: {len(chips_path)}") if stage == "fit": trn_paths, val_paths = train_test_split( chips_path, test_size=(1 - self.split_ratio), stratify=chips_platform, shuffle=True, ) self.trn_ds = EODataset( chips_path=trn_paths, size=self.size, platforms=self.platforms, metadata=self.metadata, ) self.trn_sampler = ClaySampler( dataset=self.trn_ds, platforms=self.platforms, batch_size=self.batch_size, ) self.val_ds = EODataset( chips_path=val_paths, size=self.size, platforms=self.platforms, metadata=self.metadata, ) self.val_sampler = ClaySampler( dataset=self.val_ds, platforms=self.platforms, batch_size=self.batch_size, ) elif stage == "predict": self.prd_ds = EODataset( chips_path=chips_path, platform=self.platform, metadata_path=self.metadata_path, ) def train_dataloader(self): return DataLoader( self.trn_ds, num_workers=self.num_workers, batch_sampler=self.trn_sampler, collate_fn=batch_collate, pin_memory=True, prefetch_factor=4, ) def val_dataloader(self): return DataLoader( self.val_ds, num_workers=self.num_workers, batch_sampler=self.val_sampler, collate_fn=batch_collate, pin_memory=True, prefetch_factor=4, ) def predict_dataloader(self): return DataLoader( dataset=self.prd_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, ) | |
./src/callbacks.py: from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks.finetuning import BaseFinetuning class ProgressiveResizing(Callback): def __init__(self): self.resize_schedule = { 0: {"batch_size": 4, "num_workers": 4, "size": 64}, 10: {"batch_size": 2, "num_workers": 2, "size": 128}, 20: {"batch_size": 1, "num_workers": 1, "size": 256}, } def on_train_epoch_start(self, trainer, pl_module): if trainer.current_epoch in self.resize_schedule: params = self.resize_schedule[trainer.current_epoch] trainer.datamodule.size = params["size"] trainer.datamodule.batch_size = params["batch_size"] trainer.datamodule.num_workers = params["num_workers"] trainer.datamodule.setup(stage="fit") def on_validation_epoch_start(self, trainer, pl_module): if trainer.current_epoch in self.resize_schedule: params = self.resize_schedule[trainer.current_epoch] trainer.datamodule.size = params["size"] trainer.datamodule.batch_size = params["batch_size"] trainer.datamodule.num_workers = params["num_workers"] trainer.datamodule.setup(stage="validate") class LayerwiseFinetuning(BaseFinetuning): def __init__(self, phase, train_bn=True): """Initializes with phase & batch-norm information. Args: phase (List): Phases of fine-tuning the backbone network. train_bn (bool, optional): Trains just the batch-norm layers even when all the other layers of the network are freezed. Defaults to True. """ super().__init__() self.phase = phase self.train_bn = train_bn def freeze_before_training(self, pl_module): """Freezes the encoder before starting the training.""" self.freeze( modules=[ pl_module.model.encoder.patch_embedding, pl_module.model.encoder.transformer, ], train_bn=self.train_bn, ) def finetune_function(self, pl_module, epoch, optimizer): if epoch == self.phase: """Unfreezes the encoder for training.""" print(f"In Phase {self.phase}: Full throttle") self.unfreeze_and_add_param_group( modules=[ pl_module.model.encoder.patch_embedding, pl_module.model.encoder.transformer, ], optimizer=optimizer, train_bn=self.train_bn, ) params = list(pl_module.parameters()) active = list(filter(lambda p: p.requires_grad, params)) print(f"active: {len(active)}, all: {len(params)}") |
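A minimal sketch (not from the repository) of attaching these callbacks to a Lightning Trainer; the unfreeze epoch (phase=10) and max_epochs are assumed values, not taken from the repo configs:

import lightning as L

from src.callbacks import LayerwiseFinetuning, ProgressiveResizing

trainer = L.Trainer(
    max_epochs=30,
    callbacks=[ProgressiveResizing(), LayerwiseFinetuning(phase=10, train_bn=True)],
)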