Created April 2, 2019 09:04
-
-
Save ferrine/86682779389316b792052368fea814c5 to your computer and use it in GitHub Desktop.
Solution to the Catalyst fp16 forward pass.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch


def map_nested(fn, structure, cond=lambda obj: isinstance(obj, torch.Tensor)):
    """Apply ``fn`` to every element of a nested structure that satisfies ``cond``.

    Recursively walks lists, tuples (including namedtuples) and dicts, and
    returns a structure of the same shape. In it, every element ``x`` for
    which ``cond(x)`` is true is replaced by ``fn(x)``. ``cond`` is checked
    before container recursion, so a container that satisfies ``cond`` is
    passed to ``fn`` whole. Anything else is returned unchanged.

    Args:
        fn: Callable applied to each element selected by ``cond``.
        structure: The (possibly nested) object to transform.
        cond: Predicate selecting the elements ``fn`` is applied to.
            Defaults to selecting ``torch.Tensor`` instances.

    Returns:
        A new structure mirroring ``structure``:

        - Lists and tuples keep their concrete type (namedtuples included).
        - Dicts are rebuilt as plain ``dict`` objects with their keys left
          untouched; only values are transformed.
        - Empty containers are returned as-is (the same object).
    """
    def inner_map(obj):
        if cond(obj):
            return fn(obj)
        if isinstance(obj, (tuple, list)) and obj:
            mapped = map(inner_map, obj)
            if isinstance(obj, tuple) and hasattr(obj, "_fields"):
                # namedtuple constructors take their fields positionally,
                # not as a single iterable.
                return type(obj)(*mapped)
            return type(obj)(mapped)
        if isinstance(obj, dict) and obj:
            # Transform values only. Keys identify entries and must not be
            # rewritten even if they happen to satisfy ``cond``.
            return {key: inner_map(value) for key, value in obj.items()}
        return obj

    # After map_nested is called, an ``inner_map`` cell exists. It references
    # the function ``inner_map``, whose closure in turn references that same
    # cell, because the function is recursive. Rebinding the name to None
    # clears the cell and breaks the reference cycle.
    try:
        return inner_map(structure)
    finally:
        inner_map = None
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.