Last active
October 30, 2022 01:52
-
-
Save drscotthawley/81865a5c5e729b769486efb9c3f2249d to your computer and use it in GitHub Desktop.
Wrapper to give einops.rearrange an "inverse"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from einops import rearrange as _rearrange | |
class RearrangeWrapper:
    """Wrapper to endow ``einops.rearrange`` with an 'inverse' operation.

    Calling an instance is equivalent to calling ``einops.rearrange`` directly.
    In addition, the input's shape, the pattern, and any axis-length keyword
    arguments of the most recent *successful* call are recorded, so that
    ``inverse`` can undo that call. Only one call is remembered.
    """

    def __init__(self):
        # Nothing recorded yet; ``inverse`` checks for None to detect being
        # called before any forward call.
        self.shape, self.s = None, None
        self.kwargs = {}  # axis lengths passed to the last successful forward call

    def __call__(self, x, s: str, **kwargs):
        """Return ``einops.rearrange(x, s, **kwargs)`` and record how to invert it.

        ``x`` must expose a ``.shape`` attribute (e.g. a torch tensor).
        The recorded state is updated only after the rearrange succeeds, so a
        failed call leaves the previously recorded call invertible.
        """
        shape = x.shape  # read first, as before, so inputs without .shape fail early
        result = _rearrange(x, s, **kwargs)
        self.shape, self.s, self.kwargs = shape, s, kwargs
        return result

    def inverse(self,
                y,                   # tensor to invert, e.g. the result of the forward call
                infer_dim: str = '',  # axis name (from self.s) whose size is taken from y instead
                ):
        """Undo the most recent successful forward call on ``y``.

        The recorded pattern ``"lhs -> rhs"`` is reversed to ``"rhs -> lhs"``.
        Axis sizes are supplied to ``einops.rearrange`` from two sources:
        - each plain axis name on ``lhs`` gets its size from the recorded input shape;
        - axis lengths given as keyword arguments to the forward call are reused.
        ``infer_dim`` is left out so that its size is inferred from ``y``
        (useful if, e.g., the batch size changed since the forward call).

        Raises:
            AssertionError: if called before any forward call, or if the
                recorded input shape does not match the recorded pattern.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``;
        # AssertionError is kept for compatibility with existing callers.
        if self.shape is None or self.s is None:
            raise AssertionError("inverse called before forward method")
        lhs, rhs = self.s.split('->')  # 'before' and 'after' sides of the forward transform
        named_sizes = self._lhs_axes(lhs, self.shape)
        axes_info = dict(self.kwargs)
        axes_info.update(named_sizes)
        axes_info.pop(infer_dim, None)
        return _rearrange(y, rhs + ' -> ' + lhs, **axes_info)

    @staticmethod
    def _lhs_axes(lhs, shape):
        """Map each plain axis name on a pattern's input side to its size in ``shape``.

        Parenthesized groups, anonymous axes such as ``1``, and the ellipsis
        contribute no entries: a group's total size does not determine its
        members, and the others are not valid einops axis names. An ellipsis
        ``...`` absorbs however many dimensions are left over.
        """
        # Split into top-level elements: a plain token, or a list for a (group).
        elements, group = [], None
        for tok in lhs.replace('(', ' ( ').replace(')', ' ) ').split():
            if tok == '(':
                group = []
            elif tok == ')':
                elements.append(group)
                group = None
            elif group is not None:
                group.append(tok)
            else:
                elements.append(tok)

        ndim = len(shape)
        if '...' in elements:
            pos = elements.index('...')
            n_after = len(elements) - pos - 1
            if pos + n_after > ndim:
                raise AssertionError(
                    f"input shape {tuple(shape)} does not match pattern {lhs.strip()!r}")
            # Sizes aligned with elements; the ellipsis slot itself gets None.
            sizes = [*shape[:pos], None, *shape[ndim - n_after:]]
        else:
            if len(elements) != ndim:
                raise AssertionError(
                    f"input shape {tuple(shape)} does not match pattern {lhs.strip()!r}")
            sizes = list(shape)

        return {el: size for el, size in zip(elements, sizes)
                if isinstance(el, str) and el.isidentifier()}
# One shared instance serves the whole program: each forward call records its
# own pattern and shape, so differing patterns/dims need no new instance.
rearrange: RearrangeWrapper = RearrangeWrapper()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
One limitation of the above implementation is that it remembers only the most recent forward call, so it cannot "invert" more than one call "into the past":
This could perhaps be fixed by having the forward call optionally also return some "archival" info that could later be passed to `inverse`. 🤷 For example:
Output: