-
-
Save RichardScottOZ/89e054ca5c4c9fbbbf8d5d6499df856c to your computer and use it in GitHub Desktop.
vectorized `sklearn` with `xarray`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# vectorized `sklearn` with `xarray`\n", | |
"\n", | |
"Run a `sklearn` classifier on every point of a grid (longitude/X, latitude/Y, lead_time, ...) at once.\n", | |
"\n", | |
"This might be slow because `vectorize=True` loops over the grid points in Python, but the code is short.\n", | |
"\n", | |
"inspired by and based on https://renkulab.io/gitlab/lluis.palma/s2s-ai-challenge-bsc/-/blob/submission-ML_models/notebooks/S2S_ML_models.ipynb\n", | |
"\n", | |
"This also answers https://discourse.pangeo.io/t/vectorized-sklearn/1444" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## import" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import warnings\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"import xarray as xr\n", | |
"xr.set_options(display_style='text')\n", | |
"\n", | |
"import numpy as np\n", | |
"\n", | |
"from sklearn.base import clone\n", | |
"from sklearn.linear_model import LogisticRegression" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<pre><xarray.Dataset>\n", | |
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 20)\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * year (year) int64 2000 2001 2002 2003 2004 ... 2016 2017 2018 2019\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
"Data variables:\n", | |
" t2m (lead_time, year, week, X, Y) float64 0.885 0.61 ... 0.1928\n", | |
" tp (lead_time, year, week, X, Y) float64 0.0597 0.7052 ... 0.3623\n", | |
" msl (lead_time, year, week, X, Y) float64 0.5728 0.8126 ... 0.2536</pre>" | |
], | |
"text/plain": [ | |
"<xarray.Dataset>\n", | |
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 20)\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * year (year) int64 2000 2001 2002 2003 2004 ... 2016 2017 2018 2019\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
"Data variables:\n", | |
" t2m (lead_time, year, week, X, Y) float64 0.885 0.61 ... 0.1928\n", | |
" tp (lead_time, year, week, X, Y) float64 0.0597 0.7052 ... 0.3623\n", | |
" msl (lead_time, year, week, X, Y) float64 0.5728 0.8126 ... 0.2536" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# synthetic data: competition on a 5x5 grid\n", | |
"# raw forecasts; seeded generator so results are reproducible across runs\n", | |
"rng = np.random.default_rng(0)\n", | |
"X_train = xr.DataArray(\n", | |
"    rng.random((2, 20, 53, 5, 5, 3)),\n", | |
"    dims=['lead_time', 'year', 'week', 'X', 'Y', 'variable'],\n", | |
"    coords={'lead_time': [1, 2], 'year': range(2000, 2020), 'week': range(53),\n", | |
"            'X': range(5), 'Y': range(5), 'variable': ['t2m', 'tp', 'msl']},\n", | |
").to_dataset(dim='variable')\n", | |
"X_train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<pre><xarray.Dataset>\n", | |
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 2)\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * year (year) int64 2018 2019\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
"Data variables:\n", | |
" t2m (lead_time, year, week, X, Y) float64 0.8516 0.4321 ... 0.1928\n", | |
" tp (lead_time, year, week, X, Y) float64 0.9754 0.6478 ... 0.3623\n", | |
" msl (lead_time, year, week, X, Y) float64 0.9741 0.05569 ... 0.2536</pre>" | |
], | |
"text/plain": [ | |
"<xarray.Dataset>\n", | |
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 2)\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * year (year) int64 2018 2019\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
"Data variables:\n", | |
" t2m (lead_time, year, week, X, Y) float64 0.8516 0.4321 ... 0.1928\n", | |
" tp (lead_time, year, week, X, Y) float64 0.9754 0.6478 ... 0.3623\n", | |
" msl (lead_time, year, week, X, Y) float64 0.9741 0.05569 ... 0.2536" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# hold out the last two years for testing.\n", | |
"# NOTE: X_train is overwritten below, so re-running this cell without\n", | |
"# re-running the data cell would shrink the training set again.\n", | |
"X_test = X_train.isel(year=slice(-2, None))\n", | |
"X_train = X_train.isel(year=slice(None, -2))\n", | |
"X_test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<pre><xarray.Dataset>\n", | |
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 18)\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * year (year) int64 2000 2001 2002 2003 2004 ... 2014 2015 2016 2017\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
"Data variables:\n", | |
" t2m (lead_time, year, week, X, Y) float64 2.0 1.0 0.0 ... 0.0 1.0 2.0\n", | |
" tp (lead_time, year, week, X, Y) float64 0.0 2.0 0.0 ... 2.0 0.0 2.0\n", | |
" msl (lead_time, year, week, X, Y) float64 1.0 2.0 1.0 ... 2.0 1.0 1.0</pre>" | |
], | |
"text/plain": [ | |
"<xarray.Dataset>\n", | |
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 18)\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * year (year) int64 2000 2001 2002 2003 2004 ... 2014 2015 2016 2017\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
"Data variables:\n", | |
" t2m (lead_time, year, week, X, Y) float64 2.0 1.0 0.0 ... 0.0 1.0 2.0\n", | |
" tp (lead_time, year, week, X, Y) float64 0.0 2.0 0.0 ... 2.0 0.0 2.0\n", | |
" msl (lead_time, year, week, X, Y) float64 1.0 2.0 1.0 ... 2.0 1.0 1.0" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# categorized observations: 0 below 1/3, 1 in [1/3, 2/3), 2 from 2/3 upwards.\n", | |
"# Half-open bins so values exactly at 1/3 or 2/3 are not dropped into class 0.\n", | |
"y_train = (\n", | |
"    xr.where(X_train < 1/3, 0., xr.where(X_train < 2/3, 1., 2.))\n", | |
"    .where(X_train.notnull())  # keep missing values missing, not a category\n", | |
")\n", | |
"y_train" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## config" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"sample_dims = ['year', 'week']  # dimensions used as samples\n", | |
"features = ['t2m', 'tp', 'msl']  # variables used as features\n", | |
"target_var = 't2m'  # var to predict\n", | |
"\n", | |
"# sklearn method. liblinear fits multi-class targets one-vs-rest (what\n", | |
"# multi_class='auto' selected), and `multi_class` is deprecated, so it is omitted.\n", | |
"clf = LogisticRegression(penalty='l2',\n", | |
"                         solver='liblinear',\n", | |
"                         random_state=0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def atomic_function_training_LR(X_train, y_train, clf):\n", | |
"    \"\"\"Fit an independent copy of `clf` on the samples of one grid point.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    X_train : np.ndarray\n", | |
"        Features; all leading axes are sample dimensions and the last axis\n", | |
"        is the feature ('variable') dimension.\n", | |
"    y_train : np.ndarray\n", | |
"        Targets with the same sample dimensions as `X_train`.\n", | |
"    clf : sklearn estimator\n", | |
"        Template estimator. It is cloned before fitting: `fit` returns `self`,\n", | |
"        so fitting the shared template would leave every grid point pointing\n", | |
"        at one model holding the parameters of the last fit.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    The fitted clone, or None if sklearn rejected the data with a ValueError\n", | |
"    (e.g. only one class present, or NaN/inf values).\n", | |
"    \"\"\"\n", | |
"    n_features = X_train.shape[-1]\n", | |
"    n_samples = int(np.prod(X_train.shape[:-1]))\n", | |
"    # collapse all sample dimensions into one: (n_samples, n_features)\n", | |
"    X_flat = X_train.reshape(n_samples, n_features)\n", | |
"    y_flat = y_train.reshape(n_samples)\n", | |
"    model = clone(clf)\n", | |
"    try:\n", | |
"        return model.fit(X_flat, y_flat)\n", | |
"    except ValueError:\n", | |
"        # best effort: leave this grid point without a model\n", | |
"        return None" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 93 ms, sys: 2.35 ms, total: 95.4 ms\n", | |
"Wall time: 103 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# sample_dims + 'variable' are core dims; vectorize=True loops over every\n", | |
"# remaining dimension (here lead_time, X, Y) and fits one model per point\n", | |
"all_classifiers = xr.apply_ufunc(\n", | |
"    atomic_function_training_LR,\n", | |
"    X_train[features].to_array().transpose(..., 'variable'),  # 'variable' must be last\n", | |
"    y_train[target_var],\n", | |
"    kwargs={'clf': clf},\n", | |
"    input_core_dims=[sample_dims + ['variable'], sample_dims],\n", | |
"    output_dtypes=[object],\n", | |
"    dask='parallelized',\n", | |
"    vectorize=True,\n", | |
").compute()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## predict" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.54 s, sys: 29.9 ms, total: 1.57 s\n", | |
"Wall time: 1.68 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre><xarray.DataArray (lead_time: 2, X: 5, Y: 5, week: 53, category: 3)>\n", | |
"array([[[[[2.84430615e-03, 2.48019553e-01, 7.49136141e-01],\n", | |
" [1.59098678e-02, 4.24464371e-01, 5.59625761e-01],\n", | |
" [6.43132743e-01, 3.47622040e-01, 9.24521704e-03],\n", | |
" ...,\n", | |
" [2.88347752e-02, 2.97344648e-01, 6.73820577e-01],\n", | |
" [6.99026288e-01, 2.92805754e-01, 8.16795839e-03],\n", | |
" [6.84561350e-01, 3.09447865e-01, 5.99078411e-03]],\n", | |
"\n", | |
" [[3.14927916e-01, 5.73638100e-01, 1.11433984e-01],\n", | |
" [1.82388659e-03, 2.94081405e-01, 7.04094708e-01],\n", | |
" [3.03359651e-01, 5.03803615e-01, 1.92836733e-01],\n", | |
" ...,\n", | |
" [1.25277388e-03, 2.33315838e-01, 7.65431388e-01],\n", | |
" [9.91532734e-02, 5.16850748e-01, 3.83995978e-01],\n", | |
" [7.18120246e-01, 2.78537920e-01, 3.34183389e-03]],\n", | |
"\n", | |
" [[1.22389131e-02, 2.89980880e-01, 6.97780207e-01],\n", | |
" [7.48693273e-01, 2.49804366e-01, 1.50236071e-03],\n", | |
" [7.40388030e-01, 2.56002389e-01, 3.60958146e-03],\n", | |
" ...,\n", | |
"...\n", | |
" ...,\n", | |
" [1.59798246e-03, 2.40037927e-01, 7.58364091e-01],\n", | |
" [5.78522467e-03, 2.42150429e-01, 7.52064346e-01],\n", | |
" [1.93062439e-01, 4.20348759e-01, 3.86588802e-01]],\n", | |
"\n", | |
" [[4.88716204e-03, 2.38253500e-01, 7.56859338e-01],\n", | |
" [3.58386812e-01, 5.16826179e-01, 1.24787008e-01],\n", | |
" [6.64768310e-01, 3.30687396e-01, 4.54429314e-03],\n", | |
" ...,\n", | |
" [5.13090849e-02, 4.65301898e-01, 4.83389018e-01],\n", | |
" [6.93597202e-01, 3.05439592e-01, 9.63206169e-04],\n", | |
" [6.90890885e-01, 3.07015807e-01, 2.09330822e-03]],\n", | |
"\n", | |
" [[8.97167636e-03, 3.26590274e-01, 6.64438049e-01],\n", | |
" [7.34400789e-01, 2.62958955e-01, 2.64025675e-03],\n", | |
" [1.81597565e-03, 2.58011987e-01, 7.40172038e-01],\n", | |
" ...,\n", | |
" [4.70693953e-01, 4.48879066e-01, 8.04269809e-02],\n", | |
" [3.60043640e-03, 2.39024250e-01, 7.57375314e-01],\n", | |
" [6.55709623e-01, 3.37543667e-01, 6.74671012e-03]]]]])\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * category (category) float64 0.0 1.0 2.0</pre>" | |
], | |
"text/plain": [ | |
"<xarray.DataArray (lead_time: 2, X: 5, Y: 5, week: 53, category: 3)>\n", | |
"array([[[[[2.84430615e-03, 2.48019553e-01, 7.49136141e-01],\n", | |
" [1.59098678e-02, 4.24464371e-01, 5.59625761e-01],\n", | |
" [6.43132743e-01, 3.47622040e-01, 9.24521704e-03],\n", | |
" ...,\n", | |
" [2.88347752e-02, 2.97344648e-01, 6.73820577e-01],\n", | |
" [6.99026288e-01, 2.92805754e-01, 8.16795839e-03],\n", | |
" [6.84561350e-01, 3.09447865e-01, 5.99078411e-03]],\n", | |
"\n", | |
" [[3.14927916e-01, 5.73638100e-01, 1.11433984e-01],\n", | |
" [1.82388659e-03, 2.94081405e-01, 7.04094708e-01],\n", | |
" [3.03359651e-01, 5.03803615e-01, 1.92836733e-01],\n", | |
" ...,\n", | |
" [1.25277388e-03, 2.33315838e-01, 7.65431388e-01],\n", | |
" [9.91532734e-02, 5.16850748e-01, 3.83995978e-01],\n", | |
" [7.18120246e-01, 2.78537920e-01, 3.34183389e-03]],\n", | |
"\n", | |
" [[1.22389131e-02, 2.89980880e-01, 6.97780207e-01],\n", | |
" [7.48693273e-01, 2.49804366e-01, 1.50236071e-03],\n", | |
" [7.40388030e-01, 2.56002389e-01, 3.60958146e-03],\n", | |
" ...,\n", | |
"...\n", | |
" ...,\n", | |
" [1.59798246e-03, 2.40037927e-01, 7.58364091e-01],\n", | |
" [5.78522467e-03, 2.42150429e-01, 7.52064346e-01],\n", | |
" [1.93062439e-01, 4.20348759e-01, 3.86588802e-01]],\n", | |
"\n", | |
" [[4.88716204e-03, 2.38253500e-01, 7.56859338e-01],\n", | |
" [3.58386812e-01, 5.16826179e-01, 1.24787008e-01],\n", | |
" [6.64768310e-01, 3.30687396e-01, 4.54429314e-03],\n", | |
" ...,\n", | |
" [5.13090849e-02, 4.65301898e-01, 4.83389018e-01],\n", | |
" [6.93597202e-01, 3.05439592e-01, 9.63206169e-04],\n", | |
" [6.90890885e-01, 3.07015807e-01, 2.09330822e-03]],\n", | |
"\n", | |
" [[8.97167636e-03, 3.26590274e-01, 6.64438049e-01],\n", | |
" [7.34400789e-01, 2.62958955e-01, 2.64025675e-03],\n", | |
" [1.81597565e-03, 2.58011987e-01, 7.40172038e-01],\n", | |
" ...,\n", | |
" [4.70693953e-01, 4.48879066e-01, 8.04269809e-02],\n", | |
" [3.60043640e-03, 2.39024250e-01, 7.57375314e-01],\n", | |
" [6.55709623e-01, 3.37543667e-01, 6.74671012e-03]]]]])\n", | |
"Coordinates:\n", | |
" * lead_time (lead_time) int64 1 2\n", | |
" * X (X) int64 0 1 2 3 4\n", | |
" * Y (Y) int64 0 1 2 3 4\n", | |
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
" * category (category) float64 0.0 1.0 2.0" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"def atomic_function_prediction_lr(classifiers, X_test, classes=None):\n", | |
"    \"\"\"Predict class probabilities for every sample of one grid point.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    classifiers : fitted sklearn classifier or None\n", | |
"        Model for this grid point; None means training failed there.\n", | |
"    X_test : np.ndarray\n", | |
"        Features; leading axes are sample dimensions, last axis is 'variable'.\n", | |
"    classes : array-like, optional\n", | |
"        Sorted class labels shared by all grid points; fixes the order of the\n", | |
"        output columns. Defaults to `classifiers.classes_`.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    np.ndarray of shape X_test.shape[:-1] + (len(classes),). Falls back to\n", | |
"    climatology (equal probabilities) when no usable prediction exists.\n", | |
"    \"\"\"\n", | |
"    if classes is None:\n", | |
"        classes = classifiers.classes_\n", | |
"    sample_shape = X_test.shape[:-1]\n", | |
"    n_classes = len(classes)\n", | |
"    climatology = np.full(sample_shape + (n_classes,), 1 / n_classes)\n", | |
"    if classifiers is None:\n", | |
"        return climatology\n", | |
"    X_flat = X_test.reshape(-1, X_test.shape[-1])\n", | |
"    try:\n", | |
"        proba = classifiers.predict_proba(X_flat)  # (n_samples, n_fitted_classes)\n", | |
"    except ValueError as e:\n", | |
"        warnings.warn(f'prediction failed, using climatology: {e}')\n", | |
"        return climatology\n", | |
"    # a grid point may have seen only some classes; map its columns onto `classes`\n", | |
"    aligned = np.zeros((proba.shape[0], n_classes))\n", | |
"    aligned[:, np.searchsorted(classes, classifiers.classes_)] = proba\n", | |
"    return aligned.reshape(sample_shape + (n_classes,))\n", | |
"\n", | |
"\n", | |
"# union of all observed labels (NaN marks a missing observation, not a class)\n", | |
"classes = np.unique(y_train[target_var].values)\n", | |
"classes = classes[~np.isnan(classes)]\n", | |
"\n", | |
"predictions = xr.apply_ufunc(\n", | |
"    atomic_function_prediction_lr,\n", | |
"    all_classifiers,\n", | |
"    X_test[features].to_array().transpose(..., 'variable'),\n", | |
"    kwargs={'classes': classes},\n", | |
"    input_core_dims=[[], ['year', 'variable']],  # predict every test year\n", | |
"    output_core_dims=[['year', 'category']],  # new dim for predict_proba\n", | |
"    output_dtypes=[float],\n", | |
"    vectorize=True,\n", | |
"    dask='parallelized',\n", | |
").compute()\n", | |
"\n", | |
"# the per-point function returns plain arrays, so label the categories here\n", | |
"predictions = predictions.assign_coords(category=classes)\n", | |
"predictions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "xr", | |
"language": "python", | |
"name": "xr" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment