Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shoyer/5cfa4d5751e8a78a14af25f8442ad8d5 to your computer and use it in GitHub Desktop.
Save shoyer/5cfa4d5751e8a78a14af25f8442ad8d5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
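"# Demo of multi-dimensional groupby operations\n",
"\n",
"Each example below passes one or more coordinate names to `groupby`. It then either reduces each group (`.mean()`, `.sum()`), applies a function to each group (`.apply()`), or iterates over the `(label, group)` pairs."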
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"import xarray as xr\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1 dimension"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"a = xr.DataArray(np.random.RandomState(0).randn(3),\n",
" coords={'y': ('x', [0, 1, 1]),\n",
" 'z': ('x', ['a', 'a', 'b'])},\n",
" dims=['x'])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (grouped_x_y: 3)>\n",
"array([ 1.76405235, 0.40015721, 0.97873798])\n",
"Coordinates:\n",
" * grouped_x_y (grouped_x_y) object (0, 0) (1, 1) (2, 1)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.groupby(['x', 'y']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 3)>\n",
"array([ 1.76405235, 0.40015721, 0.97873798])\n",
"Coordinates:\n",
" y (x) int64 0 1 1\n",
" * x (x) int64 0 1 2\n",
" z (x) <U1 'a' 'a' 'b'\n",
" grouped_y_z (x) object (0, 'a') (1, 'a') (1, 'b')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.groupby(['y', 'z']).apply(lambda x: x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2 dimensions"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"coords = {'a': ('x', [0, 0, 1, 1]), 'b': ('y', [0, 0, 1, 1])}\n",
"square = xr.DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=['x', 'y'])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 4, y: 4)>\n",
"array([[ 0, 1, 2, 3],\n",
" [ 4, 5, 6, 7],\n",
" [ 8, 9, 10, 11],\n",
" [12, 13, 14, 15]])\n",
"Coordinates:\n",
" a (x) int64 0 0 1 1\n",
" b (y) int64 0 0 1 1\n",
" * x (x) int64 0 1 2 3\n",
" * y (y) int64 0 1 2 3"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (a: 2, b: 2)>\n",
"array([[ 2.5, 4.5],\n",
" [ 10.5, 12.5]])\n",
"Coordinates:\n",
" * a (a) int64 0 1\n",
" * b (b) int64 0 1"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square.groupby(['a', 'b']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 4, y: 4)>\n",
"array([[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [ 12., 13., 14., 15.]])\n",
"Coordinates:\n",
" * x (x) int64 0 1 2 3\n",
" * y (y) int64 0 1 2 3"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square.groupby(['x', 'y']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 4, y: 4)>\n",
"array([[-2.5, -1.5, -2.5, -1.5],\n",
" [ 1.5, 2.5, 1.5, 2.5],\n",
" [-2.5, -1.5, -2.5, -1.5],\n",
" [ 1.5, 2.5, 1.5, 2.5]])\n",
"Coordinates:\n",
" a (x, y) int64 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1\n",
" b (x, y) int64 0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1\n",
" grouped_a_b (x, y) object (0, 0) (0, 0) (0, 1) (0, 1) (0, 0) (0, 0) ...\n",
" * x (x) int64 0 1 2 3\n",
" * y (y) int64 0 1 2 3"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square.groupby(['a', 'b']).apply(lambda x: x - x.mean())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (grouped_a_x: 4)>\n",
"array([ 1.5, 5.5, 9.5, 13.5])\n",
"Coordinates:\n",
" * grouped_a_x (grouped_a_x) object (0, 0) (0, 1) (1, 2) (1, 3)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square.groupby(['a', 'x']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (a: 2, y: 4)>\n",
"array([[ 2., 3., 4., 5.],\n",
" [ 10., 11., 12., 13.]])\n",
"Coordinates:\n",
" * a (a) int64 0 1\n",
" * y (y) int64 0 1 2 3"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square.groupby(['a', 'y']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (b: 2, x: 4)>\n",
"array([[ 0.5, 4.5, 8.5, 12.5],\n",
" [ 2.5, 6.5, 10.5, 14.5]])\n",
"Coordinates:\n",
" * b (b) int64 0 1\n",
" * x (x) int64 0 1 2 3"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square.groupby(['x', 'b']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 4, y: 4)>\n",
"array([[-0.5, 0.5, -0.5, 0.5],\n",
" [-0.5, 0.5, -0.5, 0.5],\n",
" [-0.5, 0.5, -0.5, 0.5],\n",
" [-0.5, 0.5, -0.5, 0.5]])\n",
"Coordinates:\n",
" a (x, y) int64 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1\n",
" b (x, y) int64 0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1\n",
" grouped_x_b (x, y) object (0, 0) (0, 0) (1, 0) (1, 0) (0, 1) (0, 1) ...\n",
" * x (x) int64 0 1 2 3\n",
" * y (y) int64 0 1 2 3"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"square.groupby(['x', 'b']).apply(lambda x: x - x.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
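"**TODO:** consider automatically unstacking groups in the iterator interface. Currently, iterating over a multi-dimensional `groupby` yields each group with its grouped dimensions stacked into a single `stacked_x_y` dimension (see the output below). Restoring the original dimensions requires a manual `unstack('stacked_x_y')`, as done in the last cell of this notebook."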
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[((0, 0), <xarray.DataArray (stacked_x_y: 4)>\n",
" array([0, 1, 4, 5])\n",
" Coordinates:\n",
" a (stacked_x_y) int64 0 0 0 0\n",
" b (stacked_x_y) int64 0 0 0 0\n",
" * stacked_x_y (stacked_x_y) object (0, 0) (0, 1) (1, 0) (1, 1)),\n",
" ((0, 1), <xarray.DataArray (stacked_x_y: 4)>\n",
" array([2, 3, 6, 7])\n",
" Coordinates:\n",
" a (stacked_x_y) int64 0 0 0 0\n",
" b (stacked_x_y) int64 1 1 1 1\n",
" * stacked_x_y (stacked_x_y) object (0, 2) (0, 3) (1, 2) (1, 3)),\n",
" ((1, 0), <xarray.DataArray (stacked_x_y: 4)>\n",
" array([ 8, 9, 12, 13])\n",
" Coordinates:\n",
" a (stacked_x_y) int64 1 1 1 1\n",
" b (stacked_x_y) int64 0 0 0 0\n",
" * stacked_x_y (stacked_x_y) object (2, 0) (2, 1) (3, 0) (3, 1)),\n",
" ((1, 1), <xarray.DataArray (stacked_x_y: 4)>\n",
" array([10, 11, 14, 15])\n",
" Coordinates:\n",
" a (stacked_x_y) int64 1 1 1 1\n",
" b (stacked_x_y) int64 1 1 1 1\n",
" * stacked_x_y (stacked_x_y) object (2, 2) (2, 3) (3, 2) (3, 3))]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(square.groupby(['a', 'b']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3 dimensions"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"b = xr.DataArray(np.random.RandomState(0).randn(2, 3, 4),\n",
" coords={'xy': (('x', 'y'), [['a', 'b', 'c'], ['b', 'c', 'c']])},\n",
" dims=['x', 'y', 'z'])"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 2, y: 3, z: 4)>\n",
"array([[[ 1.76405235, 0.40015721, 0.97873798, 2.2408932 ],\n",
" [ 1.86755799, -0.97727788, 0.95008842, -0.15135721],\n",
" [-0.10321885, 0.4105985 , 0.14404357, 1.45427351]],\n",
"\n",
" [[ 0.76103773, 0.12167502, 0.44386323, 0.33367433],\n",
" [ 1.49407907, -0.20515826, 0.3130677 , -0.85409574],\n",
" [-2.55298982, 0.6536186 , 0.8644362 , -0.74216502]]])\n",
"Coordinates:\n",
" xy (x, y) <U1 'a' 'b' 'c' 'b' 'c' 'c'\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2\n",
" * z (z) int64 0 1 2 3"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 2, y: 3)>\n",
"array([[ 5.38384074, 1.68901132, 1.90569673],\n",
" [ 1.6602503 , 0.74789277, -1.77710004]])\n",
"Coordinates:\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b.groupby(['x', 'y']).sum()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 2, y: 3)>\n",
"array([[ 5.38384074, 1.68901132, 1.90569673],\n",
" [ 1.6602503 , 0.74789277, -1.77710004]])\n",
"Coordinates:\n",
" xy (x, y) <U1 'a' 'b' 'c' 'b' 'c' 'c'\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b.sum('z')"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (z: 4, x: 2, y: 3)>\n",
"array([[[ 0.41809216, 1.44530516, -0.57964303],\n",
" [ 0.34597515, 1.30710588, -2.10871481]],\n",
"\n",
" [[-0.94580298, -1.39953071, -0.06582568],\n",
" [-0.29338756, -0.39213146, 1.09789361]],\n",
"\n",
" [[-0.3672222 , 0.52783559, -0.33238061],\n",
" [ 0.02880066, 0.12609451, 1.30871121]],\n",
"\n",
" [[ 0.89493301, -0.57361004, 0.97784932],\n",
" [-0.08138825, -1.04106893, -0.29789001]]])\n",
"Coordinates:\n",
" * z (z) int64 0 1 2 3\n",
" xy (x, y) <U1 'a' 'b' 'c' 'b' 'c' 'c'\n",
" grouped_x_y (x, y) object (0, 0) (0, 1) (0, 2) (1, 0) (1, 1) (1, 2)\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b.groupby(['x', 'y']).apply(lambda x: x - x.mean())"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 2, y: 3, z: 4)>\n",
"array([[[ 0.41809216, -0.94580298, -0.3672222 , 0.89493301],\n",
" [ 1.44530516, -1.39953071, 0.52783559, -0.57361004],\n",
" [-0.57964303, -0.06582568, -0.33238061, 0.97784932]],\n",
"\n",
" [[ 0.34597515, -0.29338756, 0.02880066, -0.08138825],\n",
" [ 1.30710588, -0.39213146, 0.12609451, -1.04106893],\n",
" [-2.10871481, 1.09789361, 1.30871121, -0.29789001]]])\n",
"Coordinates:\n",
" xy (x, y) <U1 'a' 'b' 'c' 'b' 'c' 'c'\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2\n",
" * z (z) int64 0 1 2 3"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b - b.mean(['z'])"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (xy: 3, x: 2)>\n",
"array([[ 1.34596018, nan],\n",
" [ 0.42225283, 0.41506258],\n",
" [ 0.47642418, -0.12865091]])\n",
"Coordinates:\n",
" * xy (xy) object 'a' 'b' 'c'\n",
" * x (x) int64 0 1"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b.groupby(['x', 'xy']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (xy: 3)>\n",
"array([ 1.34596018, 0.4186577 , 0.07304079])\n",
"Coordinates:\n",
" * xy (xy) object 'a' 'b' 'c'"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b.groupby('xy').mean()"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (x: 2, y: 3)>\n",
"array([[ 1.34596018, 0.42225283, 0.47642418],\n",
" [ 0.41506258, 0.18697319, -0.44427501]])\n",
"Coordinates:\n",
" xy (x, y) <U1 'a' 'b' 'c' 'b' 'c' 'c'\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b.groupby('xy').mean('z')"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray (z: 4, xy: 3)>\n",
"array([[ 1.76405235, 1.31429786, -0.38737653],\n",
" [ 0.40015721, -0.42780143, 0.28635294],\n",
" [ 0.97873798, 0.69697583, 0.44051582],\n",
" [ 2.2408932 , 0.09115856, -0.04732908]])\n",
"Coordinates:\n",
" * z (z) int64 0 1 2 3\n",
" * xy (xy) object 'a' 'b' 'c'"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b.groupby('xy').mean('stacked_x_y')"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[('a', <xarray.DataArray (z: 4, stacked_x_y: 1)>\n",
" array([[ 1.76405235],\n",
" [ 0.40015721],\n",
" [ 0.97873798],\n",
" [ 2.2408932 ]])\n",
" Coordinates:\n",
" xy (stacked_x_y) <U1 'a'\n",
" * z (z) int64 0 1 2 3\n",
" * stacked_x_y (stacked_x_y) object (0, 0)),\n",
" ('b', <xarray.DataArray (z: 4, stacked_x_y: 2)>\n",
" array([[ 1.86755799, 0.76103773],\n",
" [-0.97727788, 0.12167502],\n",
" [ 0.95008842, 0.44386323],\n",
" [-0.15135721, 0.33367433]])\n",
" Coordinates:\n",
" xy (stacked_x_y) <U1 'b' 'b'\n",
" * z (z) int64 0 1 2 3\n",
" * stacked_x_y (stacked_x_y) object (0, 1) (1, 0)),\n",
" ('c', <xarray.DataArray (z: 4, stacked_x_y: 3)>\n",
" array([[-0.10321885, 1.49407907, -2.55298982],\n",
" [ 0.4105985 , -0.20515826, 0.6536186 ],\n",
" [ 0.14404357, 0.3130677 , 0.8644362 ],\n",
" [ 1.45427351, -0.85409574, -0.74216502]])\n",
" Coordinates:\n",
" xy (stacked_x_y) <U1 'c' 'c' 'c'\n",
" * z (z) int64 0 1 2 3\n",
" * stacked_x_y (stacked_x_y) object (0, 2) (1, 1) (1, 2))]"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(b.groupby('xy'))"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[('a', <xarray.DataArray (z: 4, x: 2, y: 3)>\n",
" array([[[ 1.76405235, nan, nan],\n",
" [ nan, nan, nan]],\n",
" \n",
" [[ 0.40015721, nan, nan],\n",
" [ nan, nan, nan]],\n",
" \n",
" [[ 0.97873798, nan, nan],\n",
" [ nan, nan, nan]],\n",
" \n",
" [[ 2.2408932 , nan, nan],\n",
" [ nan, nan, nan]]])\n",
" Coordinates:\n",
" xy (x, y) object 'a' nan nan nan nan nan\n",
" * z (z) int64 0 1 2 3\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2), ('b', <xarray.DataArray (z: 4, x: 2, y: 3)>\n",
" array([[[ nan, 1.86755799, nan],\n",
" [ 0.76103773, nan, nan]],\n",
" \n",
" [[ nan, -0.97727788, nan],\n",
" [ 0.12167502, nan, nan]],\n",
" \n",
" [[ nan, 0.95008842, nan],\n",
" [ 0.44386323, nan, nan]],\n",
" \n",
" [[ nan, -0.15135721, nan],\n",
" [ 0.33367433, nan, nan]]])\n",
" Coordinates:\n",
" xy (x, y) object nan 'b' nan 'b' nan nan\n",
" * z (z) int64 0 1 2 3\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2), ('c', <xarray.DataArray (z: 4, x: 2, y: 3)>\n",
" array([[[ nan, nan, -0.10321885],\n",
" [ nan, 1.49407907, -2.55298982]],\n",
" \n",
" [[ nan, nan, 0.4105985 ],\n",
" [ nan, -0.20515826, 0.6536186 ]],\n",
" \n",
" [[ nan, nan, 0.14404357],\n",
" [ nan, 0.3130677 , 0.8644362 ]],\n",
" \n",
" [[ nan, nan, 1.45427351],\n",
" [ nan, -0.85409574, -0.74216502]]])\n",
" Coordinates:\n",
" xy (x, y) object nan nan 'c' nan 'c' 'c'\n",
" * z (z) int64 0 1 2 3\n",
" * x (x) int64 0 1\n",
" * y (y) int64 0 1 2)]"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[(k, v.unstack('stacked_x_y')) for k, v in b.groupby('xy')]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment