Skip to content

Instantly share code, notes, and snippets.

@yairchu
Created November 25, 2024 10:31
Show Gist options
  • Save yairchu/126401ab7bacbd92c6382f537a15da1d to your computer and use it in GitHub Desktop.
Save yairchu/126401ab7bacbd92c6382f537a15da1d to your computer and use it in GitHub Desktop.
import autograd
import autograd.numpy as np
import numpy as orig_np
import xarray as xr
def patch_for_xarray_autograd_interop():
    """Replace ``numpy.nanmean`` with an implementation autograd can trace.

    xarray's ``mean`` calls ``numpy.nanmean`` directly rather than looking it
    up in autograd's numpy wrapper, so the replacement is installed on the
    real ``numpy`` module. This is a process-wide monkeypatch: every caller
    of ``numpy.nanmean`` sees the replacement after this function runs.
    """

    def autograd_nanmean(x, axis=None, dtype=None, out=None, keepdims=False):
        """Return the mean of ``x`` over ``axis``, ignoring NaN entries.

        Args:
            x: Array-like input (may be an autograd box).
            axis: Axis or axes to reduce over; ``None`` reduces everything.
            dtype: Optional dtype the result is cast to.
            out: Accepted only for signature compatibility; must be ``None``.
            keepdims: Keep reduced axes as size-one dimensions.

        Returns:
            The NaN-ignoring mean. Slices containing only NaN yield NaN.

        Raises:
            NotImplementedError: If ``out`` is given.
        """
        if out is not None:
            raise NotImplementedError("autograd_nanmean does not support `out`")
        # Only NaN is skipped (infinities still count, as in numpy.nanmean).
        valid = ~np.isnan(x)
        # Masking with `where` instead of boolean indexing keeps the array
        # shape, so `axis` and `keepdims` reduce along the intended axes.
        total = np.sum(np.where(valid, x, 0.0), axis=axis, keepdims=keepdims)
        count = np.sum(valid, axis=axis, keepdims=keepdims)
        r = total / count
        if dtype is not None:
            # Original np.nanmean calls `dtype.type(...)` which fails with
            # autograd boxes, whereas `.astype(dtype)` works fine.
            r = r.astype(dtype)
        return r

    # xarray's mean calls np.nanmean rather than looking for it in
    # autograd.numpy.numpy_wrapper
    orig_np.nanmean = autograd_nanmean
# Install the nanmean replacement at module level, before any of the xarray
# reductions defined below are executed.
patch_for_xarray_autograd_interop()
def foo_xarray(x):
    """Build a small xarray expression from ``x`` and return its overall mean.

    Two 20-element arrays are derived from ``x`` (a power curve along
    ``"time"`` and a scaled ramp along ``"space"``). Their product and sum
    are concatenated along a new ``"op"`` dimension, and the mean of the
    combined array is returned.
    """
    exponents = np.linspace(0, 1, 20)
    factors = np.linspace(2, 3, 20)
    along_time = xr.DataArray(x ** exponents, dims=["time"])
    along_space = xr.DataArray(x * factors, dims=["space"])

    product = along_time * along_space
    total = along_time + along_space
    combined = xr.concat([product, total], dim="op")

    # Sanity check: the underlying data must not have become an object array
    # (presumably what would happen if autograd boxes were stored
    # element-wise -- TODO confirm).
    assert combined.data.dtype != object
    return combined.mean()
def foo_xarray_workaround(x):
    """Evaluate ``foo_xarray`` and return the ``.data`` of its result.

    The xarray result is not itself wrapped in an autograd "box", so autograd
    does not recognize it as something carrying a derivative; returning its
    ``.data`` attribute instead makes the function differentiable.
    """
    xarray_result = foo_xarray(x)
    return xarray_result.data
# Demo: differentiate the workaround wrapper at x = 5.0. The f-string "="
# specifier prints the expression text followed by its value.
print(f"{autograd.grad(foo_xarray_workaround)(5.0) = }")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment