Last active
March 12, 2018 20:11
-
-
Save Dref360/5e1be92587dab430bdd4734a4f6670e3 to your computer and use it in GitHub Desktop.
Convert a Keras weight checkpoint trained on multiple GPUs into a regular (single-GPU) Keras weight checkpoint
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| from collections import defaultdict | |
| import h5py | |
| import numpy as np | |
# Kept for backward compatibility with code that imports it. HDF5 paths below
# are joined with "/" instead, because os.path.join uses "\\" on Windows, which
# HDF5 would treat as part of the name rather than as a group separator.
pjoin = os.path.join


def _as_bytes(name):
    """Return *name* as UTF-8 ``bytes`` (h5py may return ``str`` or ``bytes``)."""
    return name if isinstance(name, bytes) else name.encode('utf8')


def _copy_datasets(src_group, dest_group, prefix):
    """Copy every dataset below *src_group* into *dest_group* at ``prefix/<relative path>``."""
    def _visit(rel_path, node):
        if isinstance(node, h5py.Dataset):
            # node[()] reads the whole array; Dataset.value was removed in h5py 3.0.
            dest_group['{}/{}'.format(prefix, rel_path)] = node[()]

    src_group.visititems(_visit)


def convert_multi_to_single(multipath, fp, submodel_name='model_1'):
    """Convert an HDF5 weight file saved from a multi-GPU model into a single-GPU one.

    The multi-GPU file keeps the template model's weights in one nested group
    (``submodel_name``). Each top-level child of that group, named ``<layer>``,
    is rewritten to ``<layer>/<layer>/...`` in the output. Each layer group gets
    its own ``weight_names`` attribute, and the file root gets ``layer_names``,
    ``backend`` and (when present in the input) ``keras_version``.

    Notes:
        Tested for Xception.

    Args:
        multipath: h5 filepath of the multi-GPU weight file.
        fp: h5 filepath of the output; opened in 'w' mode, so an existing file
            is overwritten.
        submodel_name: name of the group in ``multipath`` that holds the
            template model's weights. Defaults to ``'model_1'``.

    Raises:
        ValueError: if ``multipath`` has no group called ``submodel_name``.
    """
    with h5py.File(multipath, 'r') as src:
        if submodel_name not in src:
            raise ValueError(
                "`multipath` doesn't have a subnetwork called `{!r}`.".format(submodel_name))
        ref = src[submodel_name]

        # Group weight names by layer (text before the first '/') and remember
        # the order in which layers first appear. layer_names is written in
        # this order instead of h5py's alphabetical iteration order, since a
        # positional load_weights() pairs saved layers with model layers by
        # index.
        weight_names = defaultdict(list)
        layer_order = []
        for raw in ref.attrs['weight_names']:
            w = _as_bytes(raw)
            layer = w.decode('utf8').split('/')[0]
            if layer not in weight_names:
                layer_order.append(layer)
            weight_names[layer].append(w)

        # Only layers that actually exist as groups are copied. Any group not
        # referenced by weight_names is still copied, after the ordered ones.
        ordered = [k for k in layer_order if k in ref]
        ordered += [k for k in ref if k not in weight_names]

        # The output is opened only after the input has been validated, so a
        # bad input never truncates an existing output file.
        with h5py.File(fp, 'w') as out:
            for name in ordered:
                layer_group = out.require_group(name)  # exists even with no datasets
                _copy_datasets(ref[name], layer_group, name)
                layer_group.attrs['weight_names'] = weight_names[name]

            # Set metadata
            if 'keras_version' in src.attrs:
                out.attrs['keras_version'] = src.attrs['keras_version']
            out.attrs['backend'] = b'tensorflow'
            out.attrs['layer_names'] = np.array([n.encode('utf8') for n in ordered])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment