Skip to content

Instantly share code, notes, and snippets.

@tommylees112
Created July 29, 2021 10:44
Show Gist options
  • Save tommylees112/c478f6b645f92409623b9609254d168d to your computer and use it in GitHub Desktop.
Save tommylees112/c478f6b645f92409623b9609254d168d to your computer and use it in GitHub Desktop.
Save and load pandas / xarray objects
from pathlib import Path
from typing import Dict, Union

import pandas as pd
import xarray as xr
def save_scaler(scaler: Dict[str, Union[xr.Dataset, pd.DataFrame]], run_dir: Path) -> None:
    """Save each scaler component to ``run_dir / "train_data"`` as its own file.

    xarray objects (``Dataset`` or ``DataArray``) are written as NetCDF
    (``<name>.nc``). pandas objects (``DataFrame`` or ``Series``) are written
    as CSV (``<name>.csv``). ``<name>`` is the dictionary key, lower-cased,
    with spaces replaced by underscores.

    Note: ``load_scaler`` only restores files whose names contain "scale",
    "center", "stds" or "means". Components saved under other names are
    written here but are not read back by it.

    Args:
        scaler: Mapping from component name to the object to save.
        run_dir: Run directory. Its ``train_data`` sub-directory is created
            if it does not already exist.
    """
    import warnings

    scaler_dir = run_dir / "train_data"
    # to_netcdf / to_csv do not create missing parent directories.
    scaler_dir.mkdir(parents=True, exist_ok=True)

    for key, value in scaler.items():
        stem = key.lower().replace(" ", "_")
        if isinstance(value, (xr.Dataset, xr.DataArray)):
            value.to_netcdf(scaler_dir / f"{stem}.nc")
        elif isinstance(value, (pd.DataFrame, pd.Series)):
            value.to_csv(scaler_dir / f"{stem}.csv")
        else:
            # Keep the original best-effort behaviour (skip the value), but make
            # the omission visible instead of silently losing a component.
            warnings.warn(
                f"save_scaler: skipping {key!r} of unsupported type {type(value).__name__}"
            )
def load_scaler(run_dir: Path) -> Dict[str, Union[xr.Dataset, pd.DataFrame]]:
    """Load scaler components from ``run_dir / "train_data"``.

    Only ``*.nc`` and ``*.csv`` files whose names contain "scale", "center",
    "stds" or "means" are read. All other files in the directory are ignored.

    Args:
        run_dir: Run directory containing the ``train_data`` sub-directory.

    Returns:
        A mapping from each file's stem (its name without the final
        extension) to the loaded object:

        - ``*.nc`` files become an in-memory ``xr.Dataset``. A ``DataArray``
          that was saved as NetCDF comes back as a ``Dataset``.
        - ``*.csv`` files become a ``pd.DataFrame`` whose index is the first
          column. A ``Series`` that was saved as CSV comes back as a
          one-column ``DataFrame``.

        If the same stem exists as both ``.nc`` and ``.csv``, the CSV entry
        wins.
    """
    scaler_dir = run_dir / "train_data"
    keywords = ("scale", "center", "stds", "means")

    def _is_scaler_file(path: Path) -> bool:
        # Select only files whose name mentions one of the known components.
        return any(word in path.name for word in keywords)

    scaler = {}
    # NetCDF is read first, so a CSV with the same stem overrides it below.
    for path in scaler_dir.glob("*.nc"):
        if _is_scaler_file(path):
            # path.stem (not name.split(".")[0]) keeps keys that contain dots.
            # load_dataset reads the data into memory and closes the file.
            # open_dataset would leave a file handle open for every component.
            scaler[path.stem] = xr.load_dataset(path)
    for path in scaler_dir.glob("*.csv"):
        if _is_scaler_file(path):
            scaler[path.stem] = pd.read_csv(path, index_col=0)
    return scaler
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment