Last active
July 31, 2019 07:46
-
-
Save d02k01/7b10d53c1872dfbfefd91e40dd1247d0 to your computer and use it in GitHub Desktop.
ResNeXt50, ResNeXt101
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fetch the Keras fork and keras-applications, pin each checkout to a fixed
# commit, then install both from the local clones together with the pinned
# numpy / tensorflow versions.
#
# The steps are chained with `&&` so that a failed clone or checkout stops the
# sequence instead of letting `pip install` run against a missing or unpinned
# tree. `git -C <dir>` runs the checkout inside the clone without changing the
# shell's working directory, so no `cd` / `cd ../` round-trip is needed.
git clone https://github.com/ocjosen/keras &&
git -C keras checkout 48875c680a1f22dc26ece8acb37178557af67a38 &&
git clone https://github.com/keras-team/keras-applications &&
git -C keras-applications checkout eaf8aed7e42ec568ac5760382f313df026826b1d &&
pip install numpy==1.16 tensorflow==1.13.1 ./keras/ ./keras-applications/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import keras | |
import numpy as np | |
from keras.applications import ResNeXt50, ResNeXt101 | |
def _build_reduce_kernel(in_channels, out_channels):
    """Return the fixed 0/1 kernel used by a converted `_2_reduce` layer.

    With ``c = in_channels // out_channels``, output channel ``i`` receives
    weight 1 from the input channels ``start, start + c, start + 2 * c, ...``
    (up to ``start + c * c``, exclusive), where
    ``start = (i // c) * c * c + i % c``. Every other weight is 0.

    Args:
        in_channels: Number of input channels of the 1x1 convolution.
        out_channels: Number of output channels (filters) of the convolution.

    Returns:
        A float32 array of shape ``(1, 1, out_channels * c, out_channels)``.
    """
    c = in_channels // out_channels
    kernel = np.zeros((1, 1, out_channels * c, out_channels), dtype='f')
    for i in range(out_channels):
        start = (i // c) * c * c + i % c
        kernel[:, :, start:start + c * c:c, i] = 1
    return kernel


def save_converted_model(model_name, include_top, verify=True):
    """Convert a pretrained ResNeXt's `_2_reduce` layers to Conv2D and save it.

    The ImageNet model is built, its architecture is serialized to JSON, and
    every layer whose name contains ``_2_reduce`` is rewritten in that config
    as a bias-free, non-trainable 1x1 ``Conv2D``. A new model is created from
    the edited config; all other layers get the original weights, and the
    replaced layers get a fixed 0/1 kernel. The resulting weights are written
    to ``<model_name>_weights_tf_dim_ordering_tf_kernels[_notop].h5`` (a
    relative path, so the file lands in the current working directory).

    Args:
        model_name: ``'resnext50'`` or ``'resnext101'``.
        include_top: Forwarded to the model constructor; also selects the
            ``_notop`` file-name suffix when false.
        verify: If true, run the original and the converted model on the same
            seeded random input and require exactly equal outputs.

    Raises:
        ValueError: If ``model_name`` is not one of the two supported names.
            Raised before any Keras state is touched.
        AssertionError: If the rebuilt model's layer count differs from the
            original's, or if verification finds differing outputs. These are
            raised explicitly so they still fire under ``python -O``.
    """
    # Validate first so that a bad name does not reset the Keras session.
    if model_name not in ('resnext50', 'resnext101'):
        raise ValueError(
            '`model_name` can be `resnext50` or `resnext101` only.')

    keras.backend.clear_session()
    target_layer_suffix = '_2_reduce'
    channels_index = (
        3 if keras.backend.image_data_format() == 'channels_last' else 1)

    def build_fn():
        # Builds a fresh pretrained model; called again later for verification.
        if model_name == 'resnext50':
            return ResNeXt50(weights='imagenet', include_top=include_top)
        return ResNeXt101(weights='imagenet', include_top=include_top)

    old_model = build_fn()
    config = json.loads(old_model.to_json())
    old_weights_list = [layer.get_weights() for layer in old_model.layers]
    # Drop the original model and its graph before building the new one.
    del old_model
    keras.backend.clear_session()

    for layer_config in config['config']['layers']:
        if target_layer_suffix in layer_config['name']:
            layer_config['class_name'] = 'Conv2D'
            # NOTE(review): presumably the original layer is a Lambda whose
            # serialized `function` entry carries an array as its first
            # default argument; the filter count is read from that array's
            # channel axis -- confirm against the pinned keras-applications.
            filters = np.array(
                layer_config['config']['function'][2][0]).shape[channels_index]
            layer_config['config'] = keras.layers.Conv2D(
                filters, 1, use_bias=False, trainable=False,
                name=layer_config['name']).get_config()

    new_model = keras.models.model_from_config(config)
    del config

    # Weights are transferred positionally, so the layer lists must line up.
    if len(old_weights_list) != len(new_model.layers):
        raise AssertionError(
            'Rebuilt model has {} layers, expected {}.'.format(
                len(new_model.layers), len(old_weights_list)))

    for old_weights, new_layer in zip(old_weights_list, new_model.layers):
        if target_layer_suffix not in new_layer.name:
            new_layer.set_weights(old_weights)
        else:
            # Kernel shape of the freshly built Conv2D: (kh, kw, in, out).
            _, _, in_channels, out_channels = new_layer.get_weights()[0].shape
            new_layer.set_weights(
                [_build_reduce_kernel(in_channels, out_channels)])
    del old_weights_list

    if include_top:
        file_name = model_name + '_weights_tf_dim_ordering_tf_kernels.h5'
    else:
        file_name = model_name + '_weights_tf_dim_ordering_tf_kernels_notop.h5'
    new_model.save_weights(file_name)

    if verify:
        np.random.seed(42)
        # The probe input is hard-coded as one 224x224 channels-last RGB image.
        x = np.random.rand(1, 224, 224, 3).astype('f')
        new_outputs = new_model.predict(x)
        keras.backend.clear_session()
        old_outputs = build_fn().predict(x)
        keras.backend.clear_session()
        # Exact (bitwise) equality is required, as in the original check.
        if not np.array_equal(old_outputs, new_outputs):
            raise AssertionError(
                'Converted model outputs differ from the original model '
                'outputs.')
if __name__ == '__main__':
    # Convert both architectures, each with and then without the classifier
    # head, in the order: resnext50 (top, no top), resnext101 (top, no top).
    jobs = [
        (name, with_top)
        for name in ('resnext50', 'resnext101')
        for with_top in (True, False)
    ]
    for name, with_top in jobs:
        save_converted_model(name, with_top)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment