Example of remote sensing dataloaders for PyTorch Lightning (a LightningDataModule and a PyTorch Dataset for the STARCOP Permian 2019 AVIRIS-NG data).
import os
import pytorch_lightning as pl
from typing import Optional, Tuple, List
import kornia.augmentation as K
from torch.utils.data import DataLoader, WeightedRandomSampler
from starcop.data import dataset
import pandas as pd
from . import feature_extration
import rasterio.windows
import numpy as np
import torch
from tqdm import tqdm
import logging
def tiled_dataframe(dataframe:pd.DataFrame, tile_size:Tuple[int, int],
                    overlap:Tuple[int, int], output_products:List[str],
                    num_workers:int=2) -> pd.DataFrame:
    """Slices each 512x512 scene of `dataframe` into overlapping tiles and computes per-tile label statistics."""
    from georeader.slices import create_windows
    dataframe_tiled_list = []
    for row in dataframe.reset_index().to_dict(orient="records"):
        del row["window_row_off"]
        del row["window_col_off"]
        del row["window_width"]
        del row["window_height"]
        for w in create_windows((512, 512),
                                window_size=tile_size, overlap=overlap, include_incomplete=False):
            row_copy = dict(row)
            row_copy["window"] = w
            dataframe_tiled_list.append(row_copy)

    dataframe_tiled = pd.DataFrame(dataframe_tiled_list)

    # Compute weighting score
    dataset_labels = dataset.STARCOPDataset(dataframe_tiled,
                                            input_products=[],
                                            output_products=output_products,
                                            weight_loss=None,
                                            spatial_augmentations=None,
                                            window_size_sample=None)
    dl_labels = DataLoader(dataset_labels, batch_size=1, num_workers=num_workers, shuffle=False)
    frac_positives = []
    for label in tqdm(dl_labels, total=len(dl_labels), desc="Computing label statistics"):
        with torch.no_grad():
            frac_positives.append(torch.sum(label["output"]).item() / np.prod(tuple(label["output"].shape)))

    dataframe_tiled["frac_positives"] = np.array(frac_positives)
    dataframe_tiled["has_plume"] = dataframe_tiled["frac_positives"] > (10 / 64**2)

    for attr_name in ["col_off", "row_off", "width", "height"]:
        dataframe_tiled[f"window_{attr_name}"] = dataframe_tiled["window"].apply(
            lambda x: None if x is None else getattr(x, attr_name))

    dataframe_tiled["id_original"] = dataframe_tiled["id"].copy()
    dataframe_tiled["id"] = dataframe_tiled.apply(
        lambda row: f"{row['id']}_r{row.window_row_off}_c{row.window_col_off}_w{row.window_width}_h{row.window_height}",
        axis=1)
    dataframe_tiled = dataframe_tiled.set_index("id")
    return dataframe_tiled
class Permian2019DataModule(pl.LightningDataModule):
    def __init__(self, settings):
        super().__init__()
        self.settings = settings
        self.products_plot = settings.products_plot
        self.batch_size = self.settings.dataloader.batch_size
        self.num_workers = self.settings.dataloader.num_workers
        self.input_products = self.settings.dataset.input_products
        self.output_products = self.settings.dataset.output_products
        self.training_size = self.settings.dataset.training_size
        self.training_size_overlap = self.settings.dataset.training_size_overlap
        self.root_folder = self.settings.dataset.root_folder
        self.train_csv = self.settings.dataset.train_csv
        self.test_csv = "test.csv"
        if self.settings.dataset.use_weight_loss:
            self.weight_loss = self.settings.dataset.weight_loss
        else:
            self.weight_loss = None
        self.weight_sampling = self.settings.dataset.weight_sampling
    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """

    def load_dataframe(self, path) -> pd.DataFrame:
        train_dataframe = pd.read_csv(path)
        train_dataframe["window"] = train_dataframe.apply(
            lambda row: rasterio.windows.Window(col_off=row.window_col_off, row_off=row.window_row_off,
                                                width=row.window_width, height=row.window_height),
            axis=1)
        train_dataframe["folder"] = train_dataframe["id"].apply(lambda x: os.path.join(self.root_folder, x))
        train_dataframe = train_dataframe.set_index("id")
        return train_dataframe
    def prepare_data(self):
        """
        Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        log = logging.getLogger(__name__)

        if self.weight_loss is not None:
            extra_types = ["input"]
            weight_loss_list = [self.weight_loss]
        else:
            extra_types = []
            weight_loss_list = []

        if self.settings.model.model_mode == "segmentation_output":
            model_output_type = "mask"
        else:
            model_output_type = "input"

        self.train_augmentations = K.AugmentationSequential(
            K.RandomRotation(p=0.5, degrees=90),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            keepdim=True,
            data_keys=["input", model_output_type] + extra_types,
        )

        # Feature extraction if needed
        raw_bands_available = feature_extration.raw_bands_available()
        self.features_extract = [f for f in self.input_products + self.output_products + weight_loss_list if
                                 f not in raw_bands_available]
        # Which products are needed as input
        self.raw_bands = [f for f in self.input_products + self.output_products + weight_loss_list if
                          f in raw_bands_available]

        train_dataset_path = os.path.join(self.root_folder, self.train_csv)
        test_dataset_path = os.path.join(self.root_folder, self.test_csv)

        products = list(self.raw_bands)
        for f in self.features_extract:
            products.extend(feature_extration.FEATURES[f]["inputs"])

        # Download the data to self.root_folder if needed
        if not os.path.exists(train_dataset_path):
            log.info(f"File for training dataset {train_dataset_path} not found, we will download the data to {self.root_folder}")
            from starcop.data import sampling_dataset
            sampling_dataset.generate_train_data_permian_2019(self.root_folder, num_processes=self.num_workers,
                                                              products=products)

        # Download the data to self.root_folder if needed
        if not os.path.exists(test_dataset_path):
            from starcop.data import sampling_dataset
            log.info(
                f"File for testing dataset {test_dataset_path} not found, we will download the data to {self.root_folder}")
            sampling_dataset.generate_test_data_permian_2019(self.root_folder, num_processes=self.num_workers,
                                                             products=products)
        # Process train dataframe
        self.train_dataframe_original = self.load_dataframe(train_dataset_path)

        # Extract features if needed
        if len(self.features_extract) > 0:
            feature_extration.extract_features(self.features_extract, self.train_dataframe_original)

        # Slice train_dataframe in windows
        if np.any(np.array(self.training_size) < np.array([512, 512])):
            name_csv, ext = os.path.splitext(self.train_csv)
            train_dataset_path_tiled = os.path.join(self.root_folder, f"{name_csv}_tiled_{self.training_size[0]}_{self.training_size[1]}{ext}")
            if not os.path.exists(train_dataset_path_tiled):
                log.info(f"Tiled dataset {train_dataset_path_tiled} not found. Generating")
                train_dataframe = tiled_dataframe(self.train_dataframe_original, tile_size=self.training_size, overlap=self.training_size_overlap,
                                                  output_products=self.output_products, num_workers=self.num_workers)
                train_dataframe[[c for c in train_dataframe.columns if c != "window"]].to_csv(train_dataset_path_tiled)
            else:
                log.info(f"Loading tiled dataset {train_dataset_path_tiled}")
                train_dataframe = pd.read_csv(train_dataset_path_tiled)
                train_dataframe["window"] = train_dataframe.apply(
                    lambda row: rasterio.windows.Window(col_off=row.window_col_off, row_off=row.window_row_off,
                                                        width=row.window_width, height=row.window_height),
                    axis=1)
                train_dataframe["folder"] = train_dataframe["id_original"].apply(lambda x: os.path.join(self.root_folder, x))
                train_dataframe = train_dataframe.set_index("id")
        else:
            train_dataframe = self.train_dataframe_original

        self.train_dataset = dataset.STARCOPDataset(train_dataframe,
                                                    input_products=self.input_products,
                                                    output_products=self.output_products,
                                                    weight_loss=self.weight_loss,
                                                    spatial_augmentations=self.train_augmentations,
                                                    window_size_sample=None)
        self.train_dataset_plot = dataset.STARCOPDataset(train_dataframe,
                                                         input_products=self.input_products,
                                                         output_products=self.output_products,
                                                         weight_loss=self.weight_loss,
                                                         spatial_augmentations=None,
                                                         window_size_sample=None)
        self.train_dataset_non_tiled = dataset.STARCOPDataset(self.train_dataframe_original,
                                                              input_products=self.input_products,
                                                              output_products=self.output_products,
                                                              weight_loss=self.weight_loss,
                                                              spatial_augmentations=None,
                                                              window_size_sample=None)

        # Process test dataframe
        test_dataframe = self.load_dataframe(test_dataset_path)
        test_dataframe = test_dataframe.sort_values(["has_plume", "qplume"], ascending=False)
        if len(self.features_extract) > 0:
            feature_extration.extract_features(self.features_extract, test_dataframe)

        self.test_dataset = dataset.STARCOPDataset(test_dataframe,
                                                   input_products=self.input_products,
                                                   weight_loss=self.weight_loss,
                                                   output_products=self.output_products)
        self.test_dataset_plot = dataset.STARCOPDataset(test_dataframe,
                                                        input_products=self.input_products,
                                                        weight_loss=self.weight_loss,
                                                        output_products=self.output_products)

        if "rgb_aviris" in self.products_plot and not all(b in self.input_products for b in ["TOA_AVIRIS_640nm", "TOA_AVIRIS_550nm", "TOA_AVIRIS_460nm"]):
            self.train_dataset_plot.add_rgb_aviris = True
            self.test_dataset_plot.add_rgb_aviris = True
        if "mag1c" in self.products_plot and "mag1c" not in self.input_products:
            self.train_dataset_plot.add_extra_products(["mag1c"])
            self.test_dataset_plot.add_extra_products(["mag1c"])

        self.val_dataset = self.test_dataset

        log.info("Data module ready")
        log.info(f"Input products: {self.input_products} Output products: {self.output_products} Weight loss: {self.weight_loss}")
        log.info(f"Train dataset {len(self.train_dataset)} chipsize: {self.training_size}")
        log.info(f"Val dataset {len(self.val_dataset)}")
        log.info(f"Test dataset {len(self.test_dataset)}")
    def train_plot_dataloader(self, batch_size:int, num_workers:int=0):
        if self.weight_sampling:
            # Set weight per sample
            train_dataframe = add_sample_weight(self.train_dataset_plot.dataframe)
            weight_random_sampler = WeightedRandomSampler(train_dataframe["sample_weight"].values,
                                                          num_samples=len(self.train_dataset_plot),
                                                          replacement=True)  # Must be True, otherwise we should lower num_samples
            shuffle = False
        else:
            weight_random_sampler = None
            shuffle = True
        return DataLoader(self.train_dataset_plot, batch_size=batch_size,
                          num_workers=num_workers, sampler=weight_random_sampler,
                          shuffle=shuffle)

    def test_plot_dataloader(self, batch_size:int, num_workers:int=0):
        return DataLoader(self.test_dataset_plot, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)

    def train_dataloader(self, num_workers:Optional[int]=None, batch_size:Optional[int]=None):
        """Initializes and returns the training dataloader"""
        batch_size = batch_size or self.batch_size
        num_workers = num_workers or self.num_workers
        if self.weight_sampling:
            # Set weight per sample
            train_dataframe = add_sample_weight(self.train_dataset.dataframe)
            weight_random_sampler = WeightedRandomSampler(train_dataframe["sample_weight"].values,
                                                          num_samples=len(self.train_dataset),
                                                          replacement=True)  # Must be True, otherwise we should lower num_samples
            shuffle = False
        else:
            weight_random_sampler = None
            shuffle = True
        return DataLoader(self.train_dataset, batch_size=batch_size,
                          num_workers=num_workers, sampler=weight_random_sampler,
                          shuffle=shuffle)

    def val_dataloader(self, num_workers:Optional[int]=None, batch_size:Optional[int]=None):
        """Initializes and returns the validation dataloader"""
        num_workers = num_workers or self.num_workers
        batch_size = batch_size or self.batch_size
        return DataLoader(self.val_dataset, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)

    def test_dataloader(self, num_workers:Optional[int]=None, batch_size:Optional[int]=None):
        """Initializes and returns the test dataloader"""
        num_workers = num_workers or self.num_workers
        batch_size = batch_size or self.batch_size
        return DataLoader(self.test_dataset, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)
def add_sample_weight(dataframe:pd.DataFrame) -> pd.DataFrame:
    plume_fraction = np.sum(dataframe["has_plume"]) / dataframe.shape[0]
    plume_weight = 1 / plume_fraction
    non_plume_weight = 1 / (1 - plume_fraction)
    dataframe["sample_weight"] = dataframe["has_plume"].apply(
        lambda x: plume_weight if x else non_plume_weight)
    return dataframe
# easy_train_dataset = data_module.train_dataset.dataframe[(data_module.train_dataset.dataframe.qplume >= 1000) | ~data_module.train_dataset.dataframe.has_plume].copy()
# frac_has_plume = easy_train_dataset.has_plume.sum() / easy_train_dataset.shape[0] + .03
# easy_train_dataset_select = easy_train_dataset.has_plume | (np.random.rand(easy_train_dataset.shape[0]) <= frac_has_plume)
# easy_train_dataset = easy_train_dataset[easy_train_dataset_select].copy()
# easy_train_dataset.groupby("has_plume")[["name"]].count()
# easy_train_dataset[[c for c in easy_train_dataset.columns if c != "window"]].to_csv("/AVIRISNG/Permian2019/train_easy.csv")
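A minimal usage sketch for the data module above. Everything in this sketch is an assumption built only from the attributes the module reads (settings is modelled as a plain namespace here, and the product names and paths are placeholders, not the actual STARCOP configuration):

# Hypothetical usage sketch: the settings layout is inferred from the attributes
# accessed in Permian2019DataModule.__init__ and prepare_data; product names and
# paths are placeholders.
from types import SimpleNamespace

settings = SimpleNamespace(
    products_plot=[],  # extra products only used by the plotting dataloaders
    dataloader=SimpleNamespace(batch_size=16, num_workers=4),
    dataset=SimpleNamespace(
        input_products=["mag1c"],            # placeholder input product
        output_products=["labelbinary"],     # placeholder output product
        training_size=(128, 128),
        training_size_overlap=(0, 0),
        root_folder="/path/to/Permian2019",  # placeholder path
        train_csv="train.csv",
        use_weight_loss=False,
        weight_loss=None,
        weight_sampling=True,
    ),
    model=SimpleNamespace(model_mode="segmentation_output"),
)

data_module = Permian2019DataModule(settings)
data_module.prepare_data()  # downloads the data if needed and builds the datasets
train_loader = data_module.train_dataloader()
for batch in train_loader:
    print(batch["input"].shape, batch["output"].shape)
    break

In normal Lightning training the Trainer calls prepare_data itself; it is called directly here only to show the dataloaders in isolation.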
The STARCOPDataset class referenced above (starcop.data.dataset):
from typing import List, Dict, Optional, Tuple, Any
import pandas as pd
from torch.utils.data import Dataset
import os
import rasterio
import torch
import rasterio.windows
import numpy as np
class STARCOPDataset(Dataset):
    def __init__(self, dataframe:pd.DataFrame,
                 input_products:List[str],
                 output_products: List[str],
                 weight_loss:Optional[str]=None,
                 spatial_augmentations=None,
                 extra_products:Optional[List[str]]=None,
                 window_size_sample:Optional[Tuple[int, int]]=None):
        self.dataframe = dataframe
        assert "folder" in self.dataframe.columns, "folder not in columns of dataframe"
        self.input_products = input_products
        self.output_products = output_products
        if "window" not in self.dataframe:
            self.dataframe["window"] = None
        self.weight_loss = weight_loss
        self.spatial_augmentations = spatial_augmentations
        self.window_size_sample = window_size_sample
        self.extra_products = [] if extra_products is None else extra_products
        self.add_rgb_aviris = False

    def add_extra_products(self, products_add: List[str]):
        p_add = [p for p in products_add if p not in self.extra_products and p not in self.input_products]
        self.extra_products.extend(p_add)

    def __len__(self):
        return self.dataframe.shape[0]
    def __getitem__(self, idx: int):
        data_iter = self.dataframe.iloc[idx]
        product_folder = data_iter.folder
        window = data_iter.window

        if self.window_size_sample is not None:
            # Randomly sample a window of size window_size_sample to read from within the given window
            if window is None:
                row_off = np.random.randint(0, 512 - self.window_size_sample[0])
                col_off = np.random.randint(0, 512 - self.window_size_sample[1])
                window = rasterio.windows.Window(row_off=row_off, col_off=col_off, width=self.window_size_sample[1],
                                                 height=self.window_size_sample[0])
            else:
                row_off = window.row_off + np.random.randint(0, window.height - self.window_size_sample[0])
                col_off = window.col_off + np.random.randint(0, window.width - self.window_size_sample[1])
                window = rasterio.windows.Window(row_off=row_off, col_off=col_off, width=self.window_size_sample[1],
                                                 height=self.window_size_sample[0])

        out_dict = {}
        names_outputs = ["input", "output"]
        output_products = [self.input_products, self.output_products]
        if self.weight_loss is not None:
            names_outputs.append("weight_loss")
            output_products.append([self.weight_loss])

        for io_name, products in zip(names_outputs, output_products):
            tensors = []
            for key_name in products:
                path = os.path.join(product_folder, f"{key_name}.tif")
                with rasterio.open(path) as src:
                    tensors.append(torch.from_numpy(src.read(window=window)))
            if len(tensors) > 1:
                out_dict[io_name] = torch.cat(tensors, dim=0).float()
            elif len(tensors) == 1:
                out_dict[io_name] = tensors[0].float()

        # Add extra products to the dict. This is useful for plotting to include mag1c in WV3 models
        for key_name in self.extra_products:
            path = os.path.join(product_folder, f"{key_name}.tif")
            with rasterio.open(path) as src:
                out_dict[key_name] = torch.from_numpy(src.read(window=window))
            names_outputs.append(key_name)

        if self.add_rgb_aviris:
            rgb_aviris = []
            for key_name in ["TOA_AVIRIS_640nm", "TOA_AVIRIS_550nm", "TOA_AVIRIS_460nm"]:
                path = os.path.join(product_folder, f"{key_name}.tif")
                with rasterio.open(path) as src:
                    rgb_aviris.append(torch.from_numpy(src.read(window=window)))
            out_dict["rgb_aviris"] = torch.cat(rgb_aviris, dim=0).float() / 50.

        if self.spatial_augmentations is not None:
            out_list = [out_dict[k] for k in names_outputs]
            out_list_aug = self.spatial_augmentations(*out_list)
            out_dict = {k: v for k, v in zip(names_outputs, out_list_aug)}

        # Add id and has_plume
        out_dict["id"] = str(data_iter.name)
        out_dict["has_plume"] = int(data_iter.has_plume)
        return out_dict
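For reference, a standalone sketch of using STARCOPDataset directly. The dataset only requires a dataframe with a "folder" column pointing at directories that contain one GeoTIFF per product (plus a "has_plume" column for the metadata it returns); the paths and product names below are placeholders, not the real Permian 2019 layout:

# Hypothetical standalone usage: folders and product names are placeholders.
import pandas as pd
from torch.utils.data import DataLoader

df = pd.DataFrame({
    "id": ["tile_0001", "tile_0002"],
    "folder": ["/data/tile_0001", "/data/tile_0002"],  # each folder holds <product>.tif files
    "has_plume": [True, False],
}).set_index("id")

ds = STARCOPDataset(df,
                    input_products=["mag1c"],           # reads mag1c.tif from each folder
                    output_products=["labelbinary"])    # placeholder label product name
loader = DataLoader(ds, batch_size=2, shuffle=False)
sample = next(iter(loader))
print(sample["id"], sample["input"].shape, sample["output"].shape)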