DataTree dask integration example
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test `DataTree` and `dask` Integration\n",
    "\n",
    "The purpose of this notebook is to load a multi-group dataset whose variables are backed by `dask` arrays, and to demonstrate that the top-level integration of the `dask` collection API lets us treat `DataTree`s themselves as `dask` collections. This code is based on @jbusecke's [CMIP demo](https://github.com/jbusecke/presentation_ams_2023) from AMS 2023.\n",
    "\n",
    "You will need to install `xarray-datatree` directly from the branch at [PR#196](https://github.com/xarray-contrib/datatree/pull/196)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Initialize a local `distributed` cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from distributed import Client\n",
    "\n",
    "client = Client()\n",
    "client"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Key package imports; you may need to install some dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import xarray as xr\n",
    "\n",
    "from xmip.preprocessing import combined_preprocessing\n",
    "from xmip.utils import google_cmip_col"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data\n",
    "\n",
    "Here we use `intake-esm` indirectly through `xmip` to select a portion of the CMIP model data to retrieve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "col = google_cmip_col()\n",
    "query = dict(\n",
    "    source_id=[\n",
    "        'IPSL-CM6A-LR',\n",
    "        'MPI-ESM1-2-LR',\n",
    "        'GFDL-ESM4',\n",
    "        'EC-Earth3',\n",
    "        'CMCC-ESM2',\n",
    "        'CESM2',\n",
    "    ],\n",
    "    experiment_id=['historical', 'ssp126', 'ssp370', 'ssp245', 'ssp585'],\n",
    "    grid_label='gn',\n",
    ")\n",
    "cat = col.search(\n",
    "    **query,\n",
    "    variable_id='tos',\n",
    "    member_id=['r1i1p1f1'],  # add e.g. 'r2i1p1f1' to include more members\n",
    "    table_id='Omon',\n",
    ")\n",
    "kwargs = dict(\n",
    "    preprocess=combined_preprocessing,\n",
    "    xarray_open_kwargs=dict(use_cftime=True),\n",
    "    aggregate=False,\n",
    ")\n",
    "ddict = cat.to_dataset_dict(**kwargs)"
   ]
  },
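  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional, purely illustrative sanity check, we can peek at a few of the keys `intake-esm` returned; each key encodes the facets of one dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sanity check: list the first few dataset keys\n",
    "sorted(ddict)[:3]"
   ]
  },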
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_area = col.search(\n",
    "    **query,\n",
    "    table_id='Ofx',\n",
    "    variable_id='areacello',\n",
    ")\n",
    "ddict_area = cat_area.to_dataset_dict(**kwargs)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Further post-process the data with `xmip`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xmip.postprocessing import match_metrics\n",
    "\n",
    "ddict_w_area = match_metrics(ddict, ddict_area, 'areacello', print_statistics=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xmip.postprocessing import concat_members\n",
    "\n",
    "# trim each dataset to end no later than 2100, then concatenate members\n",
    "ddict_trimmed = {k: ds.sel(time=slice(None, '2100')) for k, ds in ddict_w_area.items()}\n",
    "ddict_combined_members = concat_members(\n",
    "    ddict_trimmed,\n",
    "    concat_kwargs={'coords': 'minimal', 'compat': 'override', 'join': 'override'},\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis\n",
    "\n",
    "Using the pre-processed data above, construct a new `DataTree` object to represent our CMIP subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datatree import DataTree\n",
    "\n",
    "# create a {path: dataset} dictionary, where each path is built from the dataset's attributes\n",
    "tree_dict = {f\"{ds.source_id}/{ds.experiment_id}/\": ds for ds in ddict_combined_members.values()}\n",
    "\n",
    "dt = DataTree.from_dict(tree_dict)\n",
    "dt"
   ]
  },
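  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal, self-contained sketch of the `DataTree.from_dict` pattern used above (the model names and data here are invented for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative toy example: build a small two-level tree from plain Datasets\n",
    "toy_tree = DataTree.from_dict({\n",
    "    f\"{model}/{exp}\": xr.Dataset({\"tos\": (\"time\", np.arange(3.0))})\n",
    "    for model in [\"model_a\", \"model_b\"]\n",
    "    for exp in [\"historical\", \"ssp585\"]\n",
    "})\n",
    "print(toy_tree)"
   ]
  },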
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dt.nbytes / 1e9  # total size in GB"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Down-select to a single member that is present in every experiment of each model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dt_single_member = DataTree()\n",
    "for model_name, model in dt.children.items():\n",
    "    member_id_values = []\n",
    "    for experiment_name, experiment in model.children.items():\n",
    "        ds = experiment.ds\n",
    "        member_id_values.append(set(ds.member_id.data))\n",
    "\n",
    "    # find the members common to every experiment of this model\n",
    "    full_members = member_id_values[0].intersection(*member_id_values)\n",
    "    # sort and take the first one\n",
    "    pick_member = sorted(full_members)[0]\n",
    "    dt_single_member[model_name] = model.sel(member_id=pick_member)"
   ]
  },
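  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The member-selection logic above boils down to a set intersection followed by a deterministic sort; a plain-Python sketch of the same idea, with invented member ids:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: intersect the member ids seen in each experiment, pick the first sorted one\n",
    "member_id_sets = [{'r1i1p1f1', 'r2i1p1f1'}, {'r1i1p1f1', 'r3i1p1f1'}]\n",
    "common = member_id_sets[0].intersection(*member_id_sets)\n",
    "sorted(common)[0]  # -> 'r1i1p1f1'"
   ]
  },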
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the area-weighted global-mean SST."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# average sea-surface temperature globally, weighting by cell area\n",
    "def global_mean_sst(ds):\n",
    "    return ds.tos.weighted(ds.areacello.fillna(0)).mean(['x', 'y'])\n",
    "\n",
    "timeseries = dt_single_member.map_over_subtree(global_mean_sst)\n",
    "timeseries"
   ]
  },
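  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`map_over_subtree` applies a `Dataset`-to-`Dataset` (or `DataArray`) function at every node of the tree that holds data; a toy sketch on invented data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: apply a reduction at every node of a toy tree\n",
    "demo = DataTree.from_dict({\n",
    "    \"a\": xr.Dataset({\"tos\": (\"time\", np.array([1.0, 3.0]))}),\n",
    "    \"b\": xr.Dataset({\"tos\": (\"time\", np.array([2.0, 4.0]))}),\n",
    "})\n",
    "demo.map_over_subtree(lambda ds: ds.mean(\"time\"))"
   ]
  },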
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice in the `DataTree` printout above that the data are still `dask` arrays and have not been eagerly loaded into memory. We can see this more directly by pulling out the **tos** variable from a specific model run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = '/IPSL-CM6A-LR/ssp585'\n",
    "timeseries[key]['tos']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above, you should see that the data are held in a `dask` array.\n",
    "\n",
    "Compute timeseries anomalies relative to the 1950-1980 reference period."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ref_value(ds):\n",
    "    return ds.sel(time=slice('1950', '1980')).mean('time')\n",
    "\n",
    "anomaly = DataTree()\n",
    "for model_name, model in timeseries.children.items():\n",
    "    # model-specific base period\n",
    "    base_period = get_ref_value(model[\"historical\"].ds)\n",
    "    anomaly[model_name] = model - base_period  # subtree - Dataset"
   ]
  },
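  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `model - base_period` line relies on `DataTree` broadcasting binary operations over a whole subtree; a toy sketch of subtracting one reference `Dataset` from every node, on invented data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: subtract a reference Dataset from every node in a subtree\n",
    "ref = xr.Dataset({\"tos\": 1.0})\n",
    "demo2 = DataTree.from_dict({\n",
    "    \"x\": xr.Dataset({\"tos\": (\"time\", np.array([2.0, 3.0]))}),\n",
    "})\n",
    "demo2 - ref"
   ]
  },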
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although we've done a computation on the data, the result is still a `dask` array:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "anomaly[key]['tos']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eagerly compute the result of the above calculation and return it as a new object. You can watch the dashboard of the `distributed` cluster we set up in the first cell of this notebook as it executes the `dask` graphs contained in the tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "anomaly_inmem = anomaly.compute()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can confirm that the desired result has been fully computed and is available in local memory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "anomaly_inmem[key]['tos']"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "datatree_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "393ef64ddd11f947c5b4f8fab1f868967a5676a8bd8c5954286cd6bcdd1ba6fa"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}