Ocean Bench Lightning Datamodule
"""Ocean Bench Datamodules.""" | |
import itertools | |
import os | |
from collections import namedtuple | |
from typing import Any | |
import hydra | |
import numpy as np | |
import ocn_tools._src.geoprocessing.gridding as obgrid | |
import pandas as pd | |
import torch | |
import xarray as xr | |
import xrpatcher | |
from oceanbench._src.utils.hydra import pipe | |
from omegaconf import DictConfig | |
from torch import Tensor | |
from torch.utils.data import DataLoader, Dataset, Subset, default_collate | |
from torchgeo.datamodules import NonGeoDataModule | |

def get_cfg(cfg_path: str) -> Any:
    """Load and instantiate the task configuration at a given path.

    Args:
        cfg_path: The path to the configuration file.

    Returns:
        The instantiated ``task.outputs`` configuration.
    """
    with hydra.initialize("../../../config", version_base="1.3"):
        cfg = hydra.compose(cfg_path)
        # Rewrite the relative data directory on the composed config before
        # instantiating; recomposing here would silently discard this edit.
        base_dir = cfg.task.data.base_dir
        cfg.task.data.base_dir = base_dir.replace("../", "home/user/")
        return hydra.utils.call(cfg.task.outputs)
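
# Illustrative call, assuming the oceanbench config tree is on disk at the
# relative path used above:
#
#   task_cfg = get_cfg("task/osse_gf_nadir/task")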

def norm_stats(patcher: xrpatcher.XRDAPatcher) -> tuple[float, float]:
    """Compute normalization statistics (mean and std) over the 'ssh' variable.

    Args:
        patcher: Patcher whose underlying DataArray has a "variable" dimension.

    Returns:
        Mean and standard deviation of the 'ssh' slice.
    """
    return patcher.da.sel(variable="ssh").pipe(
        lambda da: (da.mean().item(), da.std().item())
    )
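
# Minimal, self-contained sketch of what `norm_stats` consumes: an
# `XRDAPatcher` over a DataArray with a "variable" dimension containing
# "ssh". The data is synthetic and the patch/stride sizes are illustrative.
def _norm_stats_demo() -> tuple[float, float]:
    da = xr.DataArray(
        np.random.randn(2, 10, 4, 4),
        dims=("variable", "time", "lat", "lon"),
        coords={"variable": ["obs", "ssh"]},
    )
    patcher = xrpatcher.XRDAPatcher(da, patches={"time": 5}, strides={"time": 1})
    return norm_stats(patcher)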

def patcher_from_osse_task(
    task: DictConfig,
    patcher_kw: dict[str, Any],
    ref_var: str = "ssh",
    split: str = "trainval",
) -> xrpatcher.XRDAPatcher:
    """Create a patcher from an OSSE task.

    Args:
        task: The OSSE task.
        patcher_kw: The patcher keyword arguments.
        ref_var: The reference variable. Defaults to 'ssh'.
        split: The split type. Defaults to 'trainval'.

    Returns:
        xrpatcher.XRDAPatcher: The created patcher.
    """
    default_domain_limits = dict(
        time=slice(*task.splits[split]),
        lat=slice(*task.domain.lat),
        lon=slice(*task.domain.lon),
    )
    domain_limits = {**default_domain_limits, **patcher_kw.get("domain_limits", {})}
    task_data = {k: v().sel(domain_limits) for k, v in task.data.items()}
    # Align every variable onto the reference variable's coordinates.
    da = xr.Dataset(
        {
            k: v.assign_coords(task_data[ref_var].coords) if k != ref_var else v
            for k, v in task_data.items()
        }
    ).to_array()
    return xrpatcher.XRDAPatcher(da, **patcher_kw)
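
# Illustrative wiring of the OSSE patcher; the patch/stride sizes below are
# examples, not values prescribed by oceanbench:
#
#   task_cfg = get_cfg("task/osse_gf_nadir/task")
#   patcher = patcher_from_osse_task(
#       task_cfg, patcher_kw=dict(patches={"time": 15}, strides={"time": 1})
#   )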

def patcher_from_ose_task(
    task: DictConfig,
    patcher_kw: dict[str, Any],
    tgt_grid_resolution: dict[str, Any],
    ref_var: str = "ssh",
    split: str = "train",
) -> xrpatcher.XRDAPatcher:
    """Create a patcher from an OSE task.

    Args:
        task: The OSE task.
        patcher_kw: The patcher keyword arguments.
        tgt_grid_resolution: The target grid resolution (step per coordinate).
        ref_var: The reference variable (unused here; kept for interface
            parity with the OSSE variant). Defaults to 'ssh'.
        split: The split type. Defaults to 'train'.

    Returns:
        xrpatcher.XRDAPatcher: The created patcher.
    """
    default_domain_limits = dict(
        time=task.splits[split],
        lat=task.domain.lat,
        lon=task.domain.lon,
    )
    domain_limits = {**default_domain_limits, **patcher_kw.get("domain_limits", {})}

    def select(da: xr.DataArray) -> xr.DataArray:
        """Restrict along-track data to the configured time/lat/lon box."""
        return (
            da.sel(time=slice(*domain_limits["time"]))
            .where(lambda da: da.lat > domain_limits["lat"][0], drop=True)
            .where(lambda da: da.lon > domain_limits["lon"][0], drop=True)
            .where(lambda da: da.lat < domain_limits["lat"][1], drop=True)
            .where(lambda da: da.lon < domain_limits["lon"][1], drop=True)
        )

    # Regular target grid onto which the along-track observations are binned.
    tgt_grid = xr.Dataset(
        coords=dict(
            lat=np.arange(*domain_limits["lat"], tgt_grid_resolution["lat"]),
            lon=np.arange(*domain_limits["lon"], tgt_grid_resolution["lon"]),
            time=pd.date_range(
                *domain_limits["time"], freq=tgt_grid_resolution["time"]
            ),
        )
    )
    data = dict(
        train=xr.combine_nested(
            [v().pipe(select) for k, v in task.data["train"].items()], concat_dim="time"
        ),
        test=xr.combine_nested(
            [v().pipe(select) for k, v in task.data["test"].items()], concat_dim="time"
        ),
    )
    da = xr.Dataset(
        {
            k: obgrid.coord_based_to_grid(v.to_dataset(name="ssh"), tgt_grid).ssh
            for k, v in data.items()
        }
    ).to_array()
    return xrpatcher.XRDAPatcher(da, **patcher_kw)
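
# `tgt_grid_resolution` supplies the step for `np.arange` over lat/lon and the
# `freq` for `pd.date_range` over time, e.g. (illustrative values):
#
#   tgt_grid_resolution = dict(lat=0.05, lon=0.05, time="1D")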

class XrTorchDataset(Dataset):
    """Dataset for xarray data backed by an XRDAPatcher."""

    def __init__(self, patcher: xrpatcher.XRDAPatcher, item_postpro=None) -> None:
        """Initialize a new instance of XrTorchDataset.

        Args:
            patcher: XR patcher.
            item_postpro: Optional postprocessing function applied to each item.
        """
        self.patcher = patcher
        self.postpro = item_postpro

    def __getitem__(self, idx: int) -> dict[str, np.ndarray]:
        """Get item at index `idx`.

        Args:
            idx: Index.

        Returns:
            Dict with 'input' and 'target' arrays, split along the leading
            "variable" dimension of the patch.
        """
        item = self.patcher[idx].load().values
        if self.postpro:
            item = self.postpro(item)
        return {"input": item[0, ...], "target": item[1, ...]}

    def reconstruct_from_batches(self, batches, **rec_kws):
        """Reconstruct a dataset from batches.

        Args:
            batches: List of batches.
            **rec_kws: Keyword arguments forwarded to ``patcher.reconstruct``.
        """
        return self.patcher.reconstruct([*itertools.chain(*batches)], **rec_kws)

    def __len__(self) -> int:
        """Get the length of the dataset."""
        return len(self.patcher)
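
# Minimal, self-contained round trip through `XrTorchDataset` on synthetic
# data; shapes and patch sizes are illustrative. The "variable" dimension is
# split so that index 0 becomes the input and index 1 the target.
def _dataset_demo() -> None:
    da = xr.DataArray(
        np.random.randn(2, 10, 4, 4).astype(np.float32),
        dims=("variable", "time", "lat", "lon"),
        coords={"variable": ["obs", "ssh"]},
    )
    patcher = xrpatcher.XRDAPatcher(da, patches={"time": 5}, strides={"time": 5})
    ds = XrTorchDataset(patcher)
    sample = ds[0]
    assert sample["input"].shape == (5, 4, 4)
    assert sample["target"].shape == (5, 4, 4)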

class OceanBenchDataModule(NonGeoDataModule):
    """Ocean Bench DataModule."""

    valid_tasks = ["osse_gf_nadir", "osse_gf_nadirswot", "osse_gf_nadir_sst", "ose_gf"]

    def __init__(
        self,
        task_name: str,
        patcher_kw: dict[str, Any],
        ref_var: str = "ssh",
        batch_size: int = 32,
        num_workers: int = 0,
        tgt_grid_resolution: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Initialize a new instance of OceanBenchDataModule.

        Args:
            task_name: Task name, one of `valid_tasks`.
            patcher_kw: Keyword arguments for xrpatcher.XRDAPatcher.
            ref_var: Reference variable for the patcher.
            batch_size: Batch size.
            num_workers: Number of workers.
            tgt_grid_resolution: Target grid resolution, required for OSE tasks.
        """
        super().__init__(XrTorchDataset, batch_size, num_workers, **kwargs)
        assert task_name in self.valid_tasks, f"Task must be one of {self.valid_tasks}"
        self.task_name = task_name
        self.task_cfg = get_cfg(f"task/{task_name}/task")
        self.ref_var = ref_var
        self.patcher_kw = patcher_kw
        # Collate function for tensors.
        self.collate_fn = default_collate
        if "osse" in self.task_name:
            self.patcher_task_fn = patcher_from_osse_task
            train_split = "trainval"
        else:
            # OSE patchers additionally need the target grid resolution and
            # use the 'train' split key.
            assert (
                tgt_grid_resolution is not None
            ), "OSE tasks require tgt_grid_resolution"
            self.patcher_task_fn = functools.partial(
                patcher_from_ose_task, tgt_grid_resolution=tgt_grid_resolution
            )
            train_split = "train"
        self.train_patcher = self.patcher_task_fn(
            self.task_cfg, self.patcher_kw, ref_var=self.ref_var, split=train_split
        )
        self.test_patcher = self.patcher_task_fn(
            self.task_cfg, self.patcher_kw, ref_var=self.ref_var, split="test"
        )
        self.mean, self.std = norm_stats(self.train_patcher)
        # Normalize with train statistics and cast to float32.
        self.item_postpro = lambda item: pipe(
            item,
            [
                lambda i: (i - self.mean) / self.std,
                lambda i: i.astype(np.float32),
            ],
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Called at the beginning of fit, validate, test, or predict. During
        distributed training, this method is called from every process across
        all the nodes. Setting state here is recommended.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ["fit", "validate"]:
            train_dataset = XrTorchDataset(
                self.train_patcher, item_postpro=self.item_postpro
            )
            # Random 80/20 train/validation split with a fixed seed so every
            # process sees the same partition.
            total_length = len(train_dataset)
            train_length = int(total_length * 0.8)
            val_length = total_length - train_length
            self.train_dataset, self.val_dataset = random_split(
                train_dataset,
                [train_length, val_length],
                generator=torch.Generator().manual_seed(42),
            )
        if stage in ["test"]:
            self.test_dataset = XrTorchDataset(
                self.test_patcher, item_postpro=self.item_postpro
            )

if __name__ == "__main__":
    dm = OceanBenchDataModule(
        task_name="osse_gf_nadir",
        patcher_kw=dict(patches={"time": 5}, strides={"time": 1}),
        ref_var="ssh",
        batch_size=32,
        num_workers=0,
    )
    dm.setup(stage="fit")
    train_loader = dm.train_dataloader()
    batch = next(iter(train_loader))
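    # `batch` is a dict of stacked tensors produced by `default_collate`; with
    # the patcher_kw above, each entry is (batch, time, lat, lon)-shaped.
    x, y = batch["input"], batch["target"]
    print(x.shape, y.shape, x.dtype)
    # Undo the train-split normalization for inspection.
    y_denorm = y * dm.std + dm.mean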