Gist EndingCredits/e5b29a62104bd31da705363e04848c78, created May 21, 2018 09:58
import tensorflow as tf


def combine_weights(in_list):
    """
    Returns a 1-D tensor built from the input list of (nested lists of)
    tensors. This is useful for things like comparing current weights with
    old weights for Elastic Weight Consolidation (EWC). All tensors must
    share a dtype, because tf.concat requires it.

    The return expression below is evaluated from the inside out:
    1.) For every element of the input list (line 3 of the expression):
            if it is a list, combine it recursively;
            otherwise leave it alone.
    2.) From the resulting list, drop the None elements and flatten the
        rest to 1-D (line 2).
    3.) If the resulting list is empty, return None (line 1);
        otherwise return the concatenation of the list.
    (All in a single expression :) )
    """
    return (lambda x: None if not x else tf.concat(x, axis=0))(
        [tf.reshape(x, [-1]) for x in
         [combine_weights(x) if isinstance(x, list) else x for x in in_list]
         if x is not None])
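The one-liner packs the three docstring steps into a single expression. Below is a minimal sketch that writes the same steps out one at a time. It is a NumPy analogue, not the original TensorFlow code: `np.reshape` and `np.concatenate` stand in for `tf.reshape` and `tf.concat` so it runs eagerly without a TF install. The names `combine_weights_np` and `ewc_penalty`, and the `lam` parameter, are hypothetical, introduced only for illustration. `ewc_penalty` assumes the standard EWC quadratic penalty, (lam / 2) * sum_i F_i * (theta_i - theta_old_i)^2, where F is a per-weight importance (e.g. Fisher) estimate laid out in the same nested structure as the weights.

```python
import numpy as np


def combine_weights_np(in_list):
    """Hypothetical NumPy analogue of combine_weights, unrolled step by step.

    Flattens a (nested) list of arrays into one 1-D array, skipping None
    entries; returns None if nothing is left.
    """
    # Step 1: recurse into sub-lists, leave every other element alone.
    expanded = [combine_weights_np(x) if isinstance(x, list) else x
                for x in in_list]
    # Step 2: drop None entries (empty sub-lists come back as None from
    # the recursion) and flatten everything else to 1-D.
    flat = [np.reshape(x, [-1]) for x in expanded if x is not None]
    # Step 3: nothing left -> None, otherwise concatenate end to end.
    if not flat:
        return None
    return np.concatenate(flat, axis=0)


def ewc_penalty(current, old, fisher, lam=1.0):
    """Hypothetical EWC-style penalty computed on flattened weights.

    Assumes current, old and fisher share the same nested structure, so
    their flattened vectors line up element by element.
    """
    theta = combine_weights_np(current)
    theta_old = combine_weights_np(old)
    f = combine_weights_np(fisher)
    # (lam / 2) * sum_i F_i * (theta_i - theta_old_i) ** 2
    return 0.5 * lam * float(np.sum(f * (theta - theta_old) ** 2))


# Usage: nested lists, None entries and empty lists are all handled.
weights = [np.ones((2, 2)), [np.arange(3.0), None], []]
print(combine_weights_np(weights))  # [1. 1. 1. 1. 0. 1. 2.]
```

One behavioural difference: `np.concatenate` silently upcasts mixed dtypes, whereas `tf.concat` raises an error if its inputs do not share a dtype.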