Created November 5, 2019 19:06
-
-
Save MikeOfZen/a26bc18850dc3203922c9e26b71c9b18 to your computer and use it in GitHub Desktop.
[Cut TF model] Use this to slice a Keras model in half; useful for transfer learning in large models. #tf #python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_next_level(layer, model):
    """Return the layers of ``model`` that directly consume an output of ``layer``.

    Tensors are matched by their ``.name``: a layer from ``model.layers`` is
    returned when any of its input tensors has the same name as one of
    ``layer``'s output tensors.

    Args:
        layer: The layer whose direct consumers are wanted. ``layer.output``
            may be a single tensor or a list/tuple of tensors.
        model: Object exposing ``layers``; each candidate's ``input`` may be a
            single tensor or a list/tuple of tensors.

    Returns:
        A list of consumer layers in ``model.layers`` order, each listed once.
        ``layer`` itself is never included.
    """

    def wrap_list(val):
        # Normalise "one tensor or a sequence of tensors" to a list. Tuples are
        # accepted too; otherwise a tuple would be treated as a single tensor
        # and the ``.name`` lookup below would raise AttributeError.
        if isinstance(val, (list, tuple)):
            return list(val)
        return [val]

    # Build the output-name set once rather than rescanning every layer's
    # inputs for each output tensor.
    output_names = {t.name for t in wrap_list(layer.output)}

    consumers = []
    for candidate in model.layers:
        # A layer whose input tensor is also its output (presumably a Keras
        # InputLayer -- confirm) must not be reported as its own successor;
        # otherwise a caller walking successors revisits it endlessly.
        if candidate is layer:
            continue
        if any(t.name in output_names for t in wrap_list(candidate.input)):
            consumers.append(candidate)
    return consumers
def get_layers_above(cutoff_layer, model):
    """Collect ``cutoff_layer`` and every layer reachable downstream of it.

    Walks the consumer graph breadth-first, using ``get_next_level`` to find
    each layer's direct consumers in ``model``.

    Args:
        cutoff_layer: The layer at which the model is cut; it is included in
            the result.
        model: Object exposing ``layers``, passed through to ``get_next_level``.

    Returns:
        A list of layers in breadth-first discovery order, starting with
        ``cutoff_layer``. Each layer appears once.
    """
    ordered = [cutoff_layer]
    # Track identity rather than hashing the layers, and never re-enqueue a
    # layer already seen: re-adding visited layers caused redundant work on
    # diamond-shaped graphs and an endless loop if a layer listed itself as
    # a successor.
    seen = {id(cutoff_layer)}
    index = 0
    while index < len(ordered):
        for successor in get_next_level(ordered[index], model):
            if id(successor) not in seen:
                seen.add(id(successor))
                ordered.append(successor)
        index += 1
    return ordered
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment