@ferrine
Created April 2, 2019 09:04
solution to catalyst fp16 forward
import torch


def map_nested(fn, structure, cond=lambda obj: isinstance(obj, torch.Tensor)):
    r"""
    Applies fn to every object in a possibly nested structure (tuples, lists,
    dicts) for which cond is true, and returns a new structure of the same
    shape. Objects that do not satisfy cond are returned unchanged.
    """
    def inner_map(obj):
        if cond(obj):
            return fn(obj)
        if isinstance(obj, (tuple, list)) and len(obj) > 0:
            return type(obj)(map(inner_map, obj))
        if isinstance(obj, dict) and len(obj) > 0:
            # Each (key, value) item is a tuple, so it is handled by the
            # tuple branch above; keys are passed through inner_map as well.
            return dict(map(inner_map, obj.items()))
        return obj

    # After map_nested is called, an inner_map cell will exist. This cell
    # holds a reference to the function inner_map, whose closure in turn
    # holds a reference back to the same cell (because the function is
    # recursive). To break this reference cycle, we set the variable to
    # None, which clears the cell.
    try:
        return inner_map(structure)
    finally:
        inner_map = None
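
A minimal sketch of how the traversal behaves. So that it runs without PyTorch, it restates the helper under the hypothetical name `map_nested_demo` (not part of the gist), with `cond` as a required argument, and uses plain floats as stand-ins for tensors. The `batch` contents are made up for illustration.

```python
def map_nested_demo(fn, structure, cond):
    """Torch-free restatement of the gist's helper, for illustration only."""
    def inner_map(obj):
        if cond(obj):
            return fn(obj)
        if isinstance(obj, (tuple, list)) and len(obj) > 0:
            return type(obj)(map(inner_map, obj))
        if isinstance(obj, dict) and len(obj) > 0:
            # (key, value) items are tuples, so they pass through the
            # tuple branch and come back as (key, value) tuples.
            return dict(map(inner_map, obj.items()))
        return obj

    try:
        return inner_map(structure)
    finally:
        # Clear the closure cell to break the function <-> cell cycle.
        inner_map = None


# Hypothetical batch: floats stand in for tensors, other values are left alone.
batch = {"features": [1.5, 2.5], "meta": ("id-7", 3), "empty": []}

doubled = map_nested_demo(
    lambda x: x * 2,
    batch,
    cond=lambda obj: isinstance(obj, float),
)
print(doubled)
```

With the gist's own `map_nested` and its default `cond`, a call such as `map_nested(lambda t: t.half(), batch)` would presumably be the fp16 use the title refers to; that reading is an assumption. Because `half()` would also convert integer tensors, a stricter `cond` that additionally checks `obj.is_floating_point()` may be preferable.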