Skip to content

Instantly share code, notes, and snippets.

@lukegre
Created September 24, 2021 15:36
Show Gist options
  • Save lukegre/615cc5c7bd1ff2873f8e765e5ec022b8 to your computer and use it in GitHub Desktop.
Save lukegre/615cc5c7bd1ff2873f8e765e5ec022b8 to your computer and use it in GitHub Desktop.
Create global maps using xarray accessors
"""
Contains a function to quickly plot xarray datasets on a map
Loading the script creates a method for xr.DataArrays that can be used as follows:
da.mean('time').map()
Defaults can also be changed by changing values in the rcMaps dictionary.
I haven't figured out how this can be changed in notebooks, but you can just
change these with the **kwargs argument.
"""
import warnings
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from cartopy import crs
from functools import wraps
from matplotlib.pyplot import text
from matplotlib import MatplotlibDeprecationWarning
from copy import deepcopy
warnings.filterwarnings("ignore", ".*All-NaN slice encountered.*")
warnings.filterwarnings("ignore", ".*invalid value encountered in less.*")
warnings.filterwarnings("ignore", ".*convolution.*")
warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)
rcMaps = {
# these are specific to the map_subplot function
'proj': crs.PlateCarree(central_longitude=205),
'land_color': '#dddddd',
'coast_res': '110m',
'coast_lw': 0.5,
'robust': True,
'round': False,
# you can add any colorbar kwarg here and it will be set as the default
'colorbar.pad': 0.02,
'colorbar.fraction': 0.1,
}
def map_subplot(
pos=111,
proj=rcMaps['proj'],
round=rcMaps['round'],
land_color=rcMaps['land_color'],
coast_res=rcMaps['coast_res'],
**kwargs
):
"""
Makes an axes object with a cartopy projection for the current figure
Parameters
----------
pos: int/list [111]
Either a 3-digit integer or three separate integers
describing the position of the subplot. If the three
integers are *nrows*, *ncols*, and *index* in order, the
subplot will take the *index* position on a grid with *nrows*
rows and *ncols* columns. *index* starts at 1 in the upper left
corner and increases to the right.
*pos* is a three digit integer, where the first digit is the
number of rows, the second the number of columns, and the third
the index of the subplot. i.e. fig.add_subplot(235) is the same as
fig.add_subplot(2, 3, 5). Note that all integers must be less than
10 for this form to work.
proj: crs.Projection()
the cartopy coord reference system object to create the projection.
Defaults to crs.PlateCarree(central_longitude=205) if not given
round: bool [True]
If the projection is stereographic, round will cut the corners and
make the plot round
land_color: str ['w']
the color of the land patches
coast_res: str ['110m']
the resolution at which coastal lines are plotted. Valid options are
110m, 50m, 10m
**kwargs:
passed to fig.add_subplot(**kwargs)
Returns
-------
mpl.collections.QuadMesh:
A modified quadmesh object that contains the following classes
figure, axes, colorbar, set_title
"""
from cartopy import feature, crs
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
fig = plt.gcf()
is_default_width = fig.get_figwidth() == plt.rcParams['figure.figsize'][0]
is_default_height = fig.get_figheight() == plt.rcParams['figure.figsize'][1]
if is_default_width and is_default_height:
n_row = pos // 100
n_col = (pos - (n_row * 100)) // 10
width = n_col * 8
height = n_row * 3.5
fig.set_size_inches(width, height)
ax = fig.add_subplot(pos, projection=proj, **kwargs)
# makes maps round
stereo_maps = (
crs.Stereographic,
crs.NorthPolarStereo,
crs.SouthPolarStereo,
)
if isinstance(ax.projection, stereo_maps) & round:
theta = np.linspace(0, 2 * np.pi, 100)
center, radius = [0.5, 0.5], 0.475
verts = np.vstack([np.sin(theta), np.cos(theta)]).T
circle = mpath.Path(verts * radius + center)
ax.set_boundary(circle, transform=ax.transAxes)
# adds features
if coast_res == '110m':
land = ax.add_feature(feature.LAND, zorder=4, color=land_color)
else:
land = ax.add_feature(
feature.NaturalEarthFeature(
'physical', 'land', coast_res, facecolor=land_color))
ax.coastlines(resolution=coast_res, color='black', linewidth=rcMaps['coast_lw'], zorder=5)
ax.outline_patch.set_lw(rcMaps['coast_lw'])
ax.outline_patch.set_zorder(5)
return {'ax': ax, 'transform': crs.PlateCarree()}
def fill_lon_gap(xds):
import numpy as np
import numpy as np
if xds.lon.min() < -10:
x = np.arange(-180.5, 180)
else:
x = np.arange(0.5, 361)
xds = xds.sel(lon=x, method='nearest').assign_coords(lon=x)
return xds
@xr.register_dataarray_accessor("map")
class Mapping(object):
def __init__(self, xarray_obj):
self._obj = xarray_obj
@wraps(map_subplot)
def __call__(self, **kwargs):
"""Plot 2D data on a map. See map.pcolormesh for all call arguments"""
return self.pcolormesh(**kwargs)
def _plot(self, plot_func='pcolormesh', **kwargs):
from numpy import ndim
from cartopy import crs
da = self._obj
da = da.squeeze()
if ndim(da) != 2:
raise ValueError('Can only plot 2D arrays with maps')
da = da.assign_coords(lon=lambda x: x.lon%360).sortby('lon')
da = fill_lon_gap(da)
self._get_cbar_kwargs(kwargs)
map_kwargs = self._get_map_kwargs(kwargs)
props = dict(robust=rcMaps['robust'], **map_subplot(**map_kwargs))
props.update(kwargs)
img = getattr(da.plot, plot_func)(**props)
if hasattr(img, 'ax'):
img.axes = img.ax
self.axes = img.axes
img.set_title = self._text
return img
@wraps(map_subplot)
def contourf(self, **kwargs):
return self._plot(**kwargs, plot_func='contourf')
@wraps(map_subplot)
def pcolormesh(self, **kwargs):
return self._plot(**kwargs, plot_func='pcolormesh')
@wraps(map_subplot)
def contour(self, **kwargs):
return self._plot(**kwargs, plot_func='contour')
def _text(self, s, x=90, y=50, ha='center', va='center', weight='bold', size=12, **props):
"""
Write a title to the map, rather than above the map.
Will remove any axes titles. These can be returned with img.axes.set_title
Parameters
----------
s : str
the text that will be the title
x : float [90]
the longitude location of the text
y : float [50]
the latitude location of the text
For the remaining parameters, see plt.text
Returns
-------
plt.text object
"""
from cartopy import crs
kwargs = dict(transform=crs.PlateCarree(), zorder=30)
kwargs.update(**props, **dict(ha=ha, va=va, weight=weight, size=size))
self.axes.set_title('')
text = self.axes.text(x, y, s, **kwargs)
return text
@staticmethod
def _get_cbar_kwargs(kwargs):
cbar_defaults = {k.split('.')[1]: v for k, v in rcMaps.items() if k.startswith('colorbar')}
cbar_opts = cbar_defaults
if kwargs.get('add_colorbar', True):
if 'cbar_kwargs' in kwargs:
cbar_opts.update(kwargs['cbar_kwargs'])
kwargs['cbar_kwargs'] = cbar_opts
return kwargs
@staticmethod
def _get_map_kwargs(kwargs):
possible_kwargs = 'pos', 'land_color', 'proj', 'round'
map_kwargs = {k: v for k, v in kwargs.items() if k in possible_kwargs}
for k in map_kwargs:
kwargs.pop(k)
return map_kwargs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment