Created November 25, 2024 10:31
Save yairchu/126401ab7bacbd92c6382f537a15da1d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings

import autograd
import autograd.numpy as np
import numpy as orig_np
import xarray as xr
def patch_for_xarray_autograd_interop():
    """Make ``numpy.nanmean`` traceable by autograd so xarray's ``mean`` works.

    xarray's ``mean`` looks up ``nanmean`` on the real ``numpy`` module rather
    than in ``autograd.numpy``, so autograd cannot trace through it. This
    replaces ``numpy.nanmean`` (process-wide) with an implementation built
    only from operations that ``autograd.numpy`` knows how to differentiate.
    """

    def autograd_nanmean(x, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
        """Arithmetic mean of *x* ignoring NaNs, mirroring ``numpy.nanmean``.

        Only NaN entries are skipped (infinities still take part), ``axis`` /
        ``keepdims`` / ``where`` control the reduction as in numpy, and a slice
        with no valid entries yields NaN plus a "Mean of empty slice" warning.

        Returns the mean (an autograd box when *x* is one), cast with
        ``.astype(dtype)`` when *dtype* is given, or *out* when it is given.
        """
        valid = ~np.isnan(x) & where
        # Zero out skipped entries instead of boolean-indexing them away:
        # indexing flattens the array, which would silently break ``axis``.
        total = np.sum(np.where(valid, x, 0), axis=axis, keepdims=keepdims)
        count = np.sum(valid, axis=axis, keepdims=keepdims)
        with orig_np.errstate(invalid="ignore"):
            # 0/0 for empty slices yields NaN; report it like numpy does below.
            r = total / count
        if orig_np.any(count == 0):
            # Same warning text numpy.nanmean emits (and xarray filters out).
            warnings.warn("Mean of empty slice", RuntimeWarning, stacklevel=2)
        if dtype is not None:
            # Original np.nanmean calls `dtype.type(...)` which fails with autograd boxes,
            # whereas `.astype(dtype)` works fine.
            r = r.astype(dtype)
        if out is not None:
            out[...] = r
            return out
        return r

    # xarray's mean calls np.nanmean rather than looking for it in autograd.numpy.numpy_wrapper
    orig_np.nanmean = autograd_nanmean


patch_for_xarray_autograd_interop()
def foo_xarray(x):
    """Build a small xarray expression from scalar *x* and return its mean.

    Two 20-element DataArrays, over dims "time" and "space", are combined
    (product and sum, broadcast by dimension name), stacked along a new "op"
    dimension and reduced with ``DataArray.mean``.
    """
    exponents = np.linspace(0, 1, 20)
    powers = xr.DataArray(x ** exponents, dims=["time"])
    scales = np.linspace(2, 3, 20)
    scaled = xr.DataArray(x * scales, dims=["space"])
    stacked = xr.concat([powers * scaled, powers + scaled], dim="op")
    # Sanity check: the backing array must not have degraded to dtype=object
    # (presumably an array of individual autograd boxes).
    assert stacked.data.dtype != object
    return stacked.mean()
def foo_xarray_workaround(x):
    """Return the raw array behind ``foo_xarray(x)`` so autograd can differentiate it.

    autograd tracks derivatives only through its own "box" objects. The
    DataArray returned by ``foo_xarray`` is not such a box, so autograd would
    not recognise it as depending on *x*; its ``.data`` attribute is what
    carries the derivative.
    """
    result = foo_xarray(x)
    return result.data
# Differentiate the xarray-based computation at x = 5.0 and print the result.
# The f-string `=` specifier also prints the expression text itself.
print(f"{autograd.grad(foo_xarray_workaround)(5.0) = }")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.