Skip to content

Instantly share code, notes, and snippets.

Created December 20, 2015 00:28
Show Gist options
  • Save lukedeo/d1899f011ae41b26fb6e to your computer and use it in GitHub Desktop.
Save lukedeo/d1899f011ae41b26fb6e to your computer and use it in GitHub Desktop.
Saving a Keras model.
Hacked-together, deepdish-style saving functionality for Keras neural networks.
[credit] deepdish
from __future__ import division, print_function, absolute_import
import numpy as np
import tables
import warnings
import sys
import six
from keras.models import model_from_json
# Types that should be saved as pytables attribute
ATTR_TYPES = (int, float, bool, six.string_types,
              np.int8, np.int16, np.int32, np.int64,
              np.uint8, np.uint16, np.uint32, np.uint64,
              np.float16, np.float32, np.float64,
              np.bool_, np.complex64, np.complex128)

# Filters applied to compressed arrays. If the BLOSC-based filter cannot be
# constructed, warn and fall back to the default `tables.Filters()`.
try:
    COMPRESSION = tables.Filters(complevel=9, complib='blosc', shuffle=True)
except Exception:
    warnings.warn("Missing BLOSC: no compression will be used.")
    COMPRESSION = tables.Filters()
def _save_level(handler, group, level, name=None, compress=True):
    """Recursively store `level` under `group` as a node called `name`.

    The encoding is chosen from the type of `level`:

    * dict -- a new group; string keys become child nodes named after the
      key, any other key is stored in a ``__pair_<hash>`` sub-group holding
      a ``key`` and a ``value`` child.
    * list / tuple -- a new group titled ``list:<n>`` / ``tuple:<n>`` whose
      entries are stored as children ``i0``, ``i1``, ...
    * numpy.ndarray -- a (optionally compressed) HDF5 array.
    * any of `ATTR_TYPES` -- an attribute on `group`.
    * None -- an empty group titled ``nonetype:``.
    * anything else -- pickled into a variable-length object array.

    Parameters
    ----------
    handler : tables.File
        Open, writable PyTables file used to create the nodes.
    group : tables.Group
        Parent group that receives the new node or attribute.
    level : anything
        The value to store.
    name : str
        Name of the node (or attribute) to create inside `group`.
    compress : bool
        If True, Numpy arrays are written as chunked arrays using the
        module-level `COMPRESSION` filters. The flag is forwarded to every
        nested value.
    """
    if isinstance(level, dict):
        # First create a new group
        new_group = handler.create_group(group, name,
                                         "dict:{}".format(len(level)))
        for k, v in level.items():
            if isinstance(k, six.string_types):
                _save_level(handler, new_group, v, name=k, compress=compress)
            else:
                # Key is not a string, so it cannot be used as a node name.
                # Store the key and the value side by side inside a new
                # group whose name is derived from the key's hash.
                hsh = hash(k)
                if hsh < 0:
                    # A leading '-' is avoided in the node name by using an
                    # 'm' prefix for negative hashes.
                    hname = 'm{}'.format(-hsh)
                else:
                    hname = '{}'.format(hsh)
                pair_group = handler.create_group(
                    new_group, '__pair_{}'.format(hname), "keyvalue_pair")
                _save_level(handler, pair_group, k, name='key',
                            compress=compress)
                _save_level(handler, pair_group, v, name='value',
                            compress=compress)

    elif isinstance(level, (list, tuple)):
        # Sequences can contain other dictionaries and numpy arrays, so we
        # don't want to serialize them. Instead, we store each entry as i0,
        # i1, etc. The group title records the sequence type and length so
        # that loading can rebuild it.
        kind = 'list' if isinstance(level, list) else 'tuple'
        new_group = handler.create_group(group, name,
                                         "{}:{}".format(kind, len(level)))
        for i, entry in enumerate(level):
            _save_level(handler, new_group, entry, name='i{}'.format(i),
                        compress=compress)

    elif isinstance(level, np.ndarray):
        atom = tables.Atom.from_dtype(level.dtype)
        if compress:
            node = handler.create_carray(group, name, atom=atom,
                                         shape=level.shape,
                                         filters=COMPRESSION)
        else:
            node = handler.create_array(group, name, atom=atom,
                                        shape=level.shape)
        node[:] = level

    elif isinstance(level, ATTR_TYPES):
        setattr(group._v_attrs, name, level)

    elif level is None:
        # Store a None as an empty group
        handler.create_group(group, name, "nonetype:")

    else:
        # The message is formatted up front: warnings.warn() expects a
        # Warning subclass as its second argument, not the offending object.
        warnings.warn('Pickling {!r}: '
                      'This may cause incompatiblities (for instance between '
                      'Python 2 and 3) and should ideally be avoided'
                      .format(level))
        node = handler.create_vlarray(group, name, tables.ObjectAtom())
        # Loading reads the pickled object back from row 0.
        node.append(level)
def _load_level(level):
    """Recursively rebuild the Python object stored in the node `level`.

    Parameters
    ----------
    level : tables.Node
        A group, variable-length array or array node of an open HDF5 file.

    Returns
    -------
    object
        For a group: a list, tuple or None when the group title starts with
        ``list:``, ``tuple:`` or ``nonetype:`` respectively, otherwise a
        dict built from the child nodes and the group's attributes. For a
        variable-length array with a single row, that row; for any other
        array, its full contents. Other node types yield None.
    """
    if isinstance(level, tables.Group):
        dct = {}
        # Load sub-groups
        for grp in level:
            lev = _load_level(grp)
            n = grp._v_name
            # Check if it's a complicated pair or a string-value pair
            if n.startswith('__pair'):
                dct[lev['key']] = lev['value']
            else:
                dct[n] = lev

        # Load attributes
        for name in level._v_attrs._f_list():
            v = level._v_attrs[name]
            # np.bytes_ is the portable spelling of np.string_, which was
            # removed in NumPy 2.0.
            if isinstance(v, np.bytes_):
                v = v.decode('utf-8')
            dct[name] = v

        title = level._v_title
        if title.startswith('list:'):
            count = int(title[len('list:'):])
            return [dct['i{}'.format(i)] for i in range(count)]
        elif title.startswith('tuple:'):
            count = int(title[len('tuple:'):])
            return tuple(dct['i{}'.format(i)] for i in range(count))
        elif title.startswith('nonetype:'):
            return None
        else:
            return dct

    elif isinstance(level, tables.VLArray):
        # A single-row VLArray holds one stored object; unwrap it.
        if level.shape == (1,):
            return level[0]
        else:
            return level[:]

    elif isinstance(level, tables.Array):
        return level[:]

    # Any other node type is not handled and yields None.
    return None
def save(path, data, compress=True):
    """
    Save any Python structure to an HDF5 file. It is particularly suited for
    Numpy arrays. This function works similar to ``numpy.save``, except if you
    save a Python object at the top level, you do not need to unwrap it from
    inside a Numpy array of type ``object`` after loading.

    Four types of objects get saved natively in HDF5, the rest get serialized
    automatically. For most needs, you should be able to stick to the four,
    which are:

    * Dictionaries
    * Lists and tuples
    * Basic data types (including strings and None)
    * Numpy arrays

    A recommendation is to always convert your data to using only these four
    ingredients. That way your data will always be retrievable by any HDF5
    reader. A class that helps you with this is `deepdish.util.Saveable`.

    This function requires the [PyTables] module to be installed.

    Parameters
    ----------
    path : file-like object or string
        File or filename to which the data is saved. For a file-like object
        its ``name`` attribute is used as the filename.
    data : anything
        Data to be saved. This can be anything from a Numpy array, a string, an
        object, or a dictionary containing all of them including more
        dictionaries.
    compress : boolean
        Whether Numpy arrays are compressed. Set to False to turn off data
        compression.

    See also
    --------
    load
    """
    # six.string_types also accepts `unicode` paths on Python 2.
    if not isinstance(path, six.string_types):
        path = path.name

    h5file = tables.open_file(path, mode='w')
    try:
        group = h5file.root
        # If the data is a dictionary, put it flatly in the root
        if isinstance(data, dict):
            for key, value in data.items():
                _save_level(h5file, group, value, name=key, compress=compress)
        else:
            # Anything else is stored under the reserved name '_top', which
            # `load` unwraps again.
            _save_level(h5file, group, data, name='_top', compress=compress)
    finally:
        # Always release the file handle, even if saving fails midway.
        h5file.close()
def load(path, unpack=True):
    """
    Loads an HDF5 saved with `save`.

    This function requires the [PyTables] module to be installed.

    Parameters
    ----------
    path : file-like object or string
        File or filename from which to load the data. For a file-like object
        its ``name`` attribute is used as the filename.
    unpack : bool
        If True, a single-entry dictionaries will be unpacked and the value
        will be returned directly. That is, if you save ``dict(a=100)``, only
        ``100`` will be loaded. A lone ``'_top'`` entry (written by `save`
        for non-dictionary data) is unpacked regardless of this flag.

    Returns
    -------
    data : anything
        Hopefully an identical reconstruction of the data that was saved.

    See also
    --------
    save
    """
    # six.string_types also accepts `unicode` paths on Python 2.
    if not isinstance(path, six.string_types):
        path = path.name

    h5file = tables.open_file(path, mode='r')
    try:
        data = _load_level(h5file.root)
    finally:
        # Always release the file handle, even if loading fails.
        h5file.close()

    # Unpack if top is the only one
    if isinstance(data, dict) and len(data) == 1:
        if '_top' in data:
            data = data['_top']
        elif unpack:
            # dict.values() is not indexable on Python 3.
            data = next(iter(data.values()))
    return data
def save_network(net, filename):
    """Save a network's configuration and weights to an HDF5 file.

    Parameters
    ----------
    net : object
        Network exposing ``to_json()`` and ``get_weights()`` (presumably a
        Keras model -- confirm against callers).
    filename : file-like object or string
        Destination passed on to `save`.
    """
    _data = {
        'config': net.to_json(),
        'weights': net.get_weights(),
    }
    save(filename, _data, compress=True)
def load_network(filename, dtype='float32'):
    """Rebuild a network from a file written by `save_network`.

    Parameters
    ----------
    filename : file-like object or string
        Source passed on to `load`.
    dtype : str
        Currently unused: the weights are applied exactly as stored. Kept
        for backward compatibility with existing callers.

    Returns
    -------
    net
        The model created by ``model_from_json`` from the stored
        configuration, with the stored weights applied.
    """
    data = load(filename)
    net = model_from_json(data['config'])
    # NOTE: casting the weights to `dtype` is intentionally not performed.
    W = data['weights']
    # Without this call the model would be returned without the saved weights.
    net.set_weights(W)
    return net
Copy link

This is really nice. It cut my model size in 1/3rd!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment