@darothen
Last active January 14, 2023 20:08
DataTree dask integration example
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test `DataTree` and `dask` Integration\n",
"\n",
"The purpose of this notebook is to load a multi-group dataset whose variables are backed by `dask` arrays, and to demonstrate that the top-level integration of the `dask` collection API lets us treat `DataTree`s themselves as `dask` collections. This code is based on @jbusecke's [CMIP demo](https://github.com/jbusecke/presentation_ams_2023) from AMS 2023.\n",
"\n",
"You will need to install `xarray-datatree` directly from the branch at [PR#196](https://github.com/xarray-contrib/datatree/pull/196)."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Initialize a local `distributed` cluster."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from distributed import Client\n",
"client = Client()\n",
"client"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Key package imports; you may need to install some dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import xarray as xr\n",
"import numpy as np\n",
"\n",
"from xmip.preprocessing import combined_preprocessing\n",
"from xmip.utils import google_cmip_col"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Data\n",
"\n",
"Here we use `intake-esm` indirectly through `xmip` to select a portion of CMIP model data to retrieve."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"col = google_cmip_col()\n",
"query = dict(\n",
"    source_id = [\n",
"        'IPSL-CM6A-LR',\n",
"        'MPI-ESM1-2-LR',\n",
"        'GFDL-ESM4',\n",
"        'EC-Earth3',\n",
"        'CMCC-ESM2',\n",
"        'CESM2',\n",
"    ],\n",
"    experiment_id = ['historical', 'ssp126', 'ssp370', 'ssp245', 'ssp585'],\n",
"    grid_label='gn',\n",
")\n",
"cat = col.search(\n",
"    **query,\n",
"    variable_id='tos',\n",
"    member_id=['r1i1p1f1',],  # 'r2i1p1f1'\n",
"    table_id='Omon'\n",
")\n",
"kwargs = dict(preprocess=combined_preprocessing, xarray_open_kwargs=dict(use_cftime=True), aggregate=False)\n",
"ddict = cat.to_dataset_dict(**kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cat_area = col.search(\n",
"    **query,\n",
"    table_id='Ofx',\n",
"    variable_id='areacello',\n",
")\n",
"ddict_area = cat_area.to_dataset_dict(**kwargs)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Further post-process the data with `xmip`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from xmip.postprocessing import match_metrics\n",
"ddict_w_area = match_metrics(ddict, ddict_area, 'areacello', print_statistics=True) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from xmip.postprocessing import concat_members\n",
"\n",
"# trim each dataset to the end of 2100, then concatenate members\n",
"ddict_trimmed = {k: ds.sel(time=slice(None, '2100')) for k, ds in ddict_w_area.items()}\n",
"ddict_combined_members = concat_members(\n",
"    ddict_trimmed,\n",
"    concat_kwargs={'coords': 'minimal', 'compat': 'override', 'join': 'override'}\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Analysis\n",
"\n",
"Using the pre-processed data above, construct a new `DataTree` object to represent our CMIP subset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datatree import DataTree\n",
"\n",
"# create a {path: dataset} dictionary, where each path is built from the dataset's attributes\n",
"tree_dict = {f\"{ds.source_id}/{ds.experiment_id}/\": ds for ds in ddict_combined_members.values()}\n",
"\n",
"dt = DataTree.from_dict(tree_dict)\n",
"dt"
]
},
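{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A side note on how `DataTree.from_dict` lays out the tree: it splits each key on `/`, so the `\"{source_id}/{experiment_id}/\"` paths above produce a two-level model/experiment hierarchy. Below is a minimal pure-Python sketch of that grouping, with placeholder strings standing in for real datasets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Group 'model/experiment/' paths into a nested {model: {experiment: ds}} dict,\n",
"# mirroring how DataTree.from_dict nests groups by path component.\n",
"paths = {\n",
"    'IPSL-CM6A-LR/historical/': 'ds0',  # placeholder 'datasets'\n",
"    'IPSL-CM6A-LR/ssp585/': 'ds1',\n",
"    'GFDL-ESM4/historical/': 'ds2',\n",
"}\n",
"nested = {}\n",
"for path, ds in paths.items():\n",
"    model, experiment = path.strip('/').split('/')\n",
"    nested.setdefault(model, {})[experiment] = ds\n",
"nested"
]
},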
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dt.nbytes / 1e9 # size in GB"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Down-select a single member from each experiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dt_single_member = DataTree()\n",
"for model_name, model in dt.children.items():\n",
"    member_id_values = []\n",
"    for experiment_name, experiment in model.children.items():\n",
"        ds = experiment.ds\n",
"        member_id_values.append(set(ds.member_id.data))\n",
"\n",
"    # find the members common to all experiments\n",
"    full_members = member_id_values[0].intersection(*member_id_values)\n",
"    # sort and take the first one\n",
"    pick_member = sorted(full_members)[0]\n",
"    dt_single_member[model_name] = model.sel(member_id=pick_member)"
]
},
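{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the member-selection logic in isolation: the intersection keeps only the `member_id`s present in every experiment, and sorting makes the pick deterministic. A pure-Python sketch with made-up member lists:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical per-experiment member_id sets for one model\n",
"member_id_values = [\n",
"    {'r1i1p1f1', 'r2i1p1f1', 'r3i1p1f1'},  # historical\n",
"    {'r1i1p1f1', 'r2i1p1f1'},              # ssp585\n",
"]\n",
"# members present in every experiment\n",
"full_members = member_id_values[0].intersection(*member_id_values)\n",
"# sort for determinism and take the first\n",
"pick_member = sorted(full_members)[0]\n",
"pick_member  # -> 'r1i1p1f1'"
]
},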
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the area-weighted global-mean SST."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# average temperature globally, weighted by cell area\n",
"def global_mean_sst(ds):\n",
"    return ds.tos.weighted(ds.areacello.fillna(0)).mean(['x', 'y'])  # .persist()\n",
"\n",
"timeseries = dt_single_member.map_over_subtree(global_mean_sst)\n",
"timeseries"
]
},
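{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Under the hood, `weighted(w).mean(dim)` computes `sum(w * x) / sum(w)` over the given dimensions, skipping cells where the data is NaN; `fillna(0)` additionally zeroes the weight wherever `areacello` itself is missing. A scalar sketch of that arithmetic with toy numbers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# toy SST values and cell areas; NaN marks missing data / missing area\n",
"tos = [10.0, 12.0, float('nan'), 14.0]\n",
"area = [1.0, 2.0, 3.0, float('nan')]\n",
"\n",
"# areacello.fillna(0): a missing area contributes zero weight\n",
"w = [0.0 if math.isnan(a) else a for a in area]\n",
"# the weighted mean also skips cells whose data is NaN\n",
"pairs = [(wi, x) for wi, x in zip(w, tos) if not math.isnan(x)]\n",
"num = sum(wi * x for wi, x in pairs)\n",
"den = sum(wi for wi, x in pairs)\n",
"num / den  # (1*10 + 2*12) / (1 + 2)"
]
},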
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice in the `DataTree` printout above that the data are `dask` arrays and have not been eagerly loaded into memory. We can see this more directly by pulling out the **tos** variable from a specific model run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"key = '/IPSL-CM6A-LR/ssp585'\n",
"timeseries[key]['tos']"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The data are indeed backed by a `dask` array.\n",
"\n",
"Compute timeseries anomalies relative to a 1950-1980 base period."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_ref_value(ds):\n",
"    return ds.sel(time=slice('1950', '1980')).mean('time')\n",
"\n",
"anomaly = DataTree()\n",
"for model_name, model in timeseries.children.items():\n",
"    # model-specific base period\n",
"    base_period = get_ref_value(model[\"historical\"].ds)\n",
"    anomaly[model_name] = model - base_period  # subtree - Dataset"
]
},
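{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The anomaly step is just elementwise subtraction of a per-model scalar reference; broadcasting lets us subtract the base-period result from every experiment in the subtree. The same arithmetic on a toy timeseries, pretending the first three points span the base period:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy global-mean SST series (hypothetical numbers)\n",
"series = [14.1, 14.3, 14.0, 14.6]\n",
"base_period = sum(series[:3]) / 3  # mean over the 'base period'\n",
"anomaly = [t - base_period for t in series]\n",
"anomaly"
]
},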
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Although we've done a computation on the data, it's still a `dask` array:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"anomaly[key]['tos']"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Eagerly compute the result of the above calculation and return it as a new object. You can monitor the `distributed` cluster we set up in the first cell of this notebook to watch it execute all of the `dask` graphs the tree contains."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"anomaly_inmem = anomaly.compute()"
]
},
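{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough analogy for what `.compute()` does (this is not the `dask` API, just an illustration): a lazy collection records the work to be done and only runs it when asked."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy stand-in for laziness: record a computation, run it on compute()\n",
"class Lazy:\n",
"    def __init__(self, func):\n",
"        self.func = func  # nothing runs yet\n",
"\n",
"    def compute(self):\n",
"        return self.func()  # execution happens here\n",
"\n",
"delayed = Lazy(lambda: sum(range(10)))\n",
"delayed.compute()  # -> 45"
]
},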
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can confirm that the result has been fully computed and is now available in local memory:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"anomaly_inmem[key]['tos']"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "datatree_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "393ef64ddd11f947c5b4f8fab1f868967a5676a8bd8c5954286cd6bcdd1ba6fa"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}