Last active
June 14, 2019 21:39
-
-
Save qubvel/8a6d23e485ebb1611b330bbfa534a6ac to your computer and use it in GitHub Desktop.
Code example for converting TF weights to Keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
from keras.layers import Conv2D, BatchNormalization, Dense | |
# NOTE!
# This is meant to be used with Python 3.6+ because it relies on dicts preserving
# insertion order (an implementation detail in CPython 3.6, guaranteed from 3.7).
def get_name(name):
    """Return the layer part of a 'layer_name/layer_var' weight key.

    Everything before the last '/' is kept, so 'block1/conv2d/kernel'
    becomes 'block1/conv2d'. A key without any '/' yields an empty string.
    """
    layer_prefix, _, _ = name.rpartition('/')
    return layer_prefix
def group_weights(weights):
    """Group the variables of each layer together, preserving order.

    ``weights`` maps 'layer_name/layer_var' keys to arrays. Consecutive
    entries whose keys share the same layer name (as given by ``get_name``)
    are collected into one list; a change of layer name starts a new list.
    Grouping is purely sequential: entries of the same layer that are not
    adjacent in ``weights`` end up in separate groups.

    Example:
        input: {
            ...: ...
            'conv2d/kernel': <np.array>,
            'conv2d/bias': <np.array>,
            ...: ...
        }
        output: [..., [...], [<conv2d/kernel-weights>, <conv2d/bias-weights>], [...], ...]

    Args:
        weights: insertion-ordered mapping of weight key -> array.

    Returns:
        List of lists of arrays, one inner list per run of same-layer keys.
        An empty mapping gives an empty list (no phantom empty group).
    """
    out_weights = []
    previous_layer_name = None
    for k, v in weights.items():
        layer_name = get_name(k)
        # A group is only ever opened when there is a weight to put in it,
        # so no empty group can be emitted for empty input.
        if out_weights and layer_name == previous_layer_name:
            out_weights[-1].append(v)
        else:
            out_weights.append([v])
            previous_layer_name = layer_name
    return out_weights
def load_weights(model, weights):
    """Load weights to Conv2D, BatchNorm, Dense layers of model sequentially.

    Layers are visited in ``model.layers`` order. The n-th layer that is a
    Conv2D, BatchNormalization or Dense instance receives ``weights[n]`` via
    ``layer.set_weights``; every other layer type is skipped. Each matched
    layer is printed as it is processed.

    Args:
        model: object exposing an iterable ``layers`` attribute.
        weights: indexable sequence of per-layer weight lists, in the same
            order as the matching layers (e.g. the output of
            ``group_weights``).

    Note:
        If ``weights`` holds fewer groups than there are matching layers,
        indexing fails (IndexError for a list). Surplus groups are ignored.
    """
    i = 0
    for layer in model.layers:
        if isinstance(layer, (Conv2D, BatchNormalization, Dense)):
            print(layer)
            # Read from the ``weights`` argument rather than a module-level
            # global, so the function works with whatever the caller passes.
            layer.set_weights(weights[i])
            i += 1
# Step 1: load the pickled dict of TF variables ('layer_name/layer_var' -> array).
# NOTE(review): `model_name` is not defined in this snippet; it has to be set
# before this point.
with open('../../checkpoints/{}/weights.pkl'.format(model_name), 'rb') as checkpoint_file:
    weights = pickle.load(checkpoint_file)

# Step 2: regroup the flat variable dict into one list of arrays per layer,
# which is the shape Keras' `layer.set_weights` expects.
groupped_weights = group_weights(weights)

# Step 3: build the Keras model with the same architecture as the TF one
# (the `...` is a placeholder for the real constructor arguments).
model = EfficientNetB0(...)

# Step 4: copy the grouped weights into the model, layer by layer.
load_weights(model, groupped_weights)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Code for dumping TF model weights (paste after checkpoint loading)