Created April 4, 2017 01:54
-
-
Save sneakers-the-rat/8f34876c57a829375e78de55e10e230e to your computer and use it in GitHub Desktop.
fix load_model() in keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## in model.load_model: | |
# add import | |
from .utils.io_utils import ask_to_proceed_with_overwrite, update_config | |
# At 225 add a version check and call an updating function
model_config = json.loads(model_config.decode('utf-8'))
# `keras_version` is presumably stored as UTF-8 bytes, like `model_config`
# (decoded above). Indexing the raw value with [0] gives an int byte code on
# Python 3 (49 for b'1...'), so a `float(...[0]) < 2` check never fires there,
# and `.get(...)[0]` raises when the attribute is absent. Decode when needed
# and treat a file without the attribute as written by Keras 1.
if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version']
    if isinstance(original_keras_version, bytes):
        original_keras_version = original_keras_version.decode('utf-8')
else:
    original_keras_version = '1'
# Only the major version matters: '1.2.2' -> 1.
k_version = int(str(original_keras_version).split('.')[0])
if k_version < 2:
    # Have to rename params to load old models
    model_config = update_config(model_config)
model = model_from_config(model_config, custom_objects=custom_objects)
## in utils.io_utils: | |
# add imports | |
from ..legacy.interfaces import all_conversions, all_value_conversions, raise_duplicate_arg_error | |
# update function
def update_config(model_config):
    """Update outdated config parameters from <2.0

    Accepts both layouts produced by Keras 1: a functional ``Model`` config,
    where ``model_config['config']`` is a dict holding a ``'layers'`` list,
    and a ``Sequential`` config, where ``model_config['config']`` is itself
    the list of layer dicts.

    # Arguments
        model_config: json-decoded model_config from load_model()

    # Returns
        model_config: the same dict, with each layer's config rewritten in
            place to use Keras 2 parameter names and values.
    """
    config = model_config['config']
    # Keras 1 Sequential configs are a bare list of layers; subscripting that
    # list with 'layers' is what raised "list indices must be integers".
    layers = config if isinstance(config, list) else config['layers']
    for layer in layers:
        _update_layer_config(layer['config'])
    return model_config


# Regularizer arguments whose Keras 1 payload is converted to an L1L2 config.
# W_regularizer's {'name', 'l1', 'l2'} shape is what the original code handled;
# b_/activity_ regularizers are assumed to serialize the same way (verify for
# custom regularizers -- those without l1/l2 are left untouched below).
_OLD_REGULARIZER_KEYS = ('W_regularizer', 'b_regularizer',
                         'activity_regularizer')


def _convert_regularizer(reg):
    """Rewrite a Keras 1 l1/l2 regularizer dict in place as a Keras 2 L1L2 config.

    Values that are not dicts (e.g. None), dicts already carrying
    'class_name', and dicts with neither 'l1' nor 'l2' are left unchanged.
    """
    if not isinstance(reg, dict) or 'class_name' in reg:
        return
    if 'l1' not in reg and 'l2' not in reg:
        # Not an l1/l2 regularizer; relabelling it as L1L2 would drop it.
        return
    reg.pop('name', None)
    l1_val = reg.pop('l1', 0.)
    l2_val = reg.pop('l2', 0.)
    reg[u'class_name'] = u'L1L2'
    reg[u'config'] = {u'l1': l1_val, u'l2': l2_val}


def _update_layer_config(layer_config):
    """Rename/convert Keras 1 parameters in one layer's config dict, in place."""
    # 2D convolution-style layers: (nb_row, nb_col) -> kernel_size. Only
    # rewrite when both halves are present instead of raising KeyError.
    if 'nb_row' in layer_config and 'nb_col' in layer_config:
        layer_config['kernel_size'] = [layer_config.pop('nb_row'),
                                       layer_config.pop('nb_col')]
    # Convert regularizer payloads while the old key names are still present.
    for reg_key in _OLD_REGULARIZER_KEYS:
        if reg_key in layer_config:
            _convert_regularizer(layer_config[reg_key])
    # Translate old option values (e.g. dim_ordering 'tf') before renaming.
    for key, value_map in all_value_conversions.items():
        if key in layer_config and layer_config[key] in value_map:
            layer_config[key] = value_map[layer_config[key]]
    for old_name, new_name in all_conversions:
        if old_name in layer_config:
            value = layer_config.pop(old_name)
            if new_name in layer_config:
                # Both spellings present: let the legacy helper report it.
                raise_duplicate_arg_error(old_name, new_name)
            layer_config[new_name] = value
## in legacy.interfaces
# make list/dict of common conversions/value conversions

# (old Keras 1 argument name, new Keras 2 argument name) pairs, applied in
# list order by update_config(). Each pair appears once; some distinct old
# names ('stride', 'subsample') intentionally map to the same new name.
all_conversions = [('output_dim', 'units'),
                   ('init', 'kernel_initializer'),
                   ('W_regularizer', 'kernel_regularizer'),
                   ('b_regularizer', 'bias_regularizer'),
                   ('W_constraint', 'kernel_constraint'),
                   ('b_constraint', 'bias_constraint'),
                   ('bias', 'use_bias'),
                   ('p', 'rate'),
                   ('pool_length', 'pool_size'),
                   ('stride', 'strides'),
                   ('border_mode', 'padding'),
                   ('sigma', 'stddev'),
                   ('nb_filter', 'filters'),
                   ('subsample', 'strides'),
                   ('dim_ordering', 'data_format'),
                   ('input_dtype', 'dtype'),
                   ('beta_init', 'beta_initializer'),
                   ('gamma_init', 'gamma_initializer')]

# Per-argument value rewrites, looked up under the OLD argument name (they
# run before the renames above): Keras 1 dim_ordering values -> Keras 2
# data_format values.
all_value_conversions = {'dim_ordering': {'tf': 'channels_last',
                                          'th': 'channels_first',
                                          'default': None}}
I'm having the same issue.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I upgraded to Keras 2.0.7 and need to make predictions with models saved in Keras 1.2.2.
When I load the saved trained model,
I get an error on the input shape:
ValueError: Error when checking : expected convolution2d_5_input to have 4 dimensions, but got array with shape (4L, 33L, 33L)
I think this script solves a different problem,
but I tried to run it and got an error:
TypeError Traceback (most recent call last)
in ()
11 v = out_h5.attrs.get("model_config")
12 config = json.loads(v)
---> 13 for i, l in enumerate(config["config"]["layers"]):
14 dtype = l["config"].pop("input_dtype", None)
15 if dtype is not None:
TypeError: list indices must be integers, not str