Skip to content

Instantly share code, notes, and snippets.

@Dref360
Last active March 12, 2018 20:11
Show Gist options
  • Select an option

  • Save Dref360/5e1be92587dab430bdd4734a4f6670e3 to your computer and use it in GitHub Desktop.

Select an option

Save Dref360/5e1be92587dab430bdd4734a4f6670e3 to your computer and use it in GitHub Desktop.
Convert a Keras weight checkpoint trained on multiple GPUs to a regular (single-GPU) Keras weight checkpoint
import os
from collections import defaultdict
import h5py
import numpy as np
pjoin = os.path.join
def convert_multi_to_single(multipath, fp):
    """
    Convert a Keras HDF5 weight file saved from a multi-GPU model into a
    regular single-GPU weight file.

    The input file must contain a top-level group named ``'model_1'``.
    Every member group of ``'model_1'`` is treated as a layer and each of
    its datasets is copied to ``<layer>/<layer>/<weight>`` in the output.
    The per-layer ``weight_names`` attribute and the file-level
    ``keras_version``, ``backend`` and ``layer_names`` attributes are
    written as well.

    Notes:
        Tested for Xception

    Args:
        multipath: h5 filepath of the multigpu model
        fp: h5 filepath of the output

    Raises:
        ValueError: if `multipath` has no top-level `'model_1'` group.
    """
    # Context managers make sure both files get closed even when an
    # exception (including the ValueError below) interrupts the conversion.
    with h5py.File(multipath, 'r') as f:
        if 'model_1' not in f:
            raise ValueError("`multipath` doesn't have a subnetwork called `'model_1'`.")
        ref = f['model_1']

        # Group the weight names by the layer they belong to (the part of
        # the name before the first '/').
        weight_names = defaultdict(list)
        for w in ref.attrs['weight_names']:
            # h5py yields bytes for fixed-length string attributes and str
            # for variable-length ones; only the lookup key is decoded, the
            # stored value is kept exactly as it was read.
            text = w.decode('utf-8') if isinstance(w, bytes) else w
            weight_names[text.split('/')[0]].append(w)

        layer_names = []
        with h5py.File(fp, 'w') as out:
            # Assign the weights and the weight names to all layers.
            for k, v in ref.items():
                name = k.split('/')[-1]
                layer_names.append(name.encode('utf-8'))
                for k1, v1 in v.items():
                    # HDF5 object paths always use '/', independent of the
                    # OS path separator. `v1[()]` reads the whole dataset
                    # (replacement for the removed `Dataset.value`).
                    out['/'.join((k, k, k1))] = v1[()]
                # require_group also covers a layer group without datasets,
                # for which no output group has been created yet.
                out.require_group(name).attrs['weight_names'] = weight_names[name]

            # Set metadata
            out.attrs['keras_version'] = f.attrs['keras_version']
            out.attrs['backend'] = str.encode('tensorflow')
            out.attrs['layer_names'] = np.array(layer_names)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment