@fsausset
Created March 31, 2017 14:43
Export a Keras model to a TensorFlow .pb file with embedded weights, for use on Android.
from keras.models import Sequential
from keras.models import model_from_json
from keras import backend as K
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
import os
# Load existing model.
with open("model.json",'r') as f:
modelJSON = f.read()
model = model_from_json(modelJSON)
model.load_weights("model_weights.hdf5")
# All new operations will be in test mode from now on.
K.set_learning_phase(0)
# Serialize the model and get its weights, for quick re-building.
config = model.get_config()
weights = model.get_weights()
# Re-build a model where the learning phase is now hard-coded to 0.
new_model = Sequential.from_config(config)
new_model.set_weights(weights)
temp_dir = "graph"
checkpoint_prefix = os.path.join(temp_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
output_graph_name = "output_graph.pb"
# Temporarily save the graph to disk, without weights included.
saver = tf.train.Saver()
checkpoint_path = saver.save(K.get_session(), checkpoint_prefix, global_step=0, latest_filename=checkpoint_state_name)
tf.train.write_graph(K.get_session().graph, temp_dir, input_graph_name)
input_graph_path = os.path.join(temp_dir, input_graph_name)
input_saver_def_path = ""
input_binary = False
output_node_names = "Softmax" # model dependent
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(temp_dir, output_graph_name)
clear_devices = False
# Embed weights inside the graph and save to disk.
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, "")
fearlessfara commented Apr 30, 2019

I am getting the same error as @elgunguliyev7, but I can't find a way to handle it. Any suggestions?

@nihalkenkre

Hello @faraoneking99. It has been a while since I used TensorFlow, so I do not remember the function names.

You can list all the nodes in the frozen graph and check the name of the node corresponding to the last layer in your Keras model, e.g. dense/softmax_2 or output/softmax_2. Then use that string as the output_node_names variable.
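
For reference, a minimal sketch of listing the node names in a frozen graph (TF 1.x API; the .pb path is an assumption):

import tensorflow as tf
# Read the frozen graph and print every node name.
graph_def = tf.GraphDef()
with tf.gfile.GFile("graph/output_graph.pb", "rb") as f:  # path is an assumption
    graph_def.ParseFromString(f.read())
for node in graph_def.node:
    print(node.name)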

fearlessfara commented May 2, 2019

Thanks @amadlover. Is it possible that I have more than one output node? My model has several output nodes, and their number can keep changing, since there is no fixed number of classes to be recognized. I'm using a customized version of VGG19, retrained on my classes, and I found that the output node name (I'm not sure if it's correct) is predictions/Softmax. If you need further info I can attach the code and/or the list of the frozen graph nodes. Note that I need to use the net in OpenCV, so if you have any other advice on that front I'd love to hear it <3

@nihalkenkre

@faraoneking99 you could try assigning output_node_names to predictions/Softmax.

Even if the number of output classes changes, the model architecture should remain the same (I might be wrong). The same model can be used to train on different sets of data.

Let me know.

fearlessfara commented May 2, 2019

@amadlover somehow I managed to get it to work using predictions/Softmax. By the way, the model does change as I add more classes: I pass the length of the classes array (where all the class names are stored) as the number of neurons in the output layer. I'll copy-paste the code here so you can get a better overview.

# -*- coding: utf-8 -*-

"""
Created on Sat Apr 13 18:06:05 2019

@author: chris
"""

# Import libraries

import numpy as np
import matplotlib.pyplot as plt
import keras
import itertools
import convert
import os

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import tensorflow as tf
from keras import backend as K
K.set_image_dim_ordering('th')
from sklearn.metrics import classification_report,confusion_matrix
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.optimizers import SGD,RMSprop,Adam
from keras.metrics import categorical_crossentropy
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import *
import pickle

#%%

sess = tf.Session()

K.set_session(sess)

#%%
train_path = 'model/train'
valid_path = 'model/valid'
test_path = 'model/test'

wkdir = ''

pb_filename = 'pesi_pb.pb'
pbtxt_filename = 'pesi_pb.pbtxt'

#%%

with open('model/config/config.cfg', 'rb') as fp:
    classi = pickle.load(fp)
print("Classes loaded correctly")

#%%
for i in classi:
    questa = 'model/config/' + i + '.cfg'
    with open(questa, 'rb') as fp:
        batchSizes = pickle.load(fp)

#%%

#trainSize = int(batchSizes[0]/4)
#testSize = int(batchSizes[1]/4)
#validSize = int(batchSizes[2]/4)

# WARNING!!! READ CAREFULLY
# Generally a batch size of 32 or 25 is good, with epochs = 100, unless you have a large dataset.
# For a large dataset you can go with a batch size of 10 and epochs between 50 and 100.
# VERY LARGE DATASET, SO WE USE A SMALL BATCH SIZE
trainSize = 10
testSize = 10
validSize = 10

print("train size totale: " + str(trainSize))
print("test size totale: " + str(testSize))
print("valid size totale: " + str(validSize))

train_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(224,224), classes=classi, batch_size=trainSize)
test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(224,224), classes=classi, batch_size=testSize)
valid_batches = ImageDataGenerator().flow_from_directory(valid_path, target_size=(224,224), classes=classi, batch_size=validSize)

#%%
stepsEpoche = len(train_batches)/10
stepsValid = len(valid_batches)/10
numEpoche = 3
#%%
# plot images with their labels
def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None):
    if type(ims[0]) is np.ndarray:
        ims = np.array(ims).astype(np.uint8)
        if (ims.shape[-1] != 3):
            ims = ims.transpose((0,2,3,1))
    f = plt.figure(figsize=figsize)
    cols = len(ims)//rows if len(ims) % 2 == 0 else len(ims)//rows + 1
    for i in range(len(ims)):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('Off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=16)
        plt.imshow(ims[i], interpolation=None if interp else 'none')

#%%

imgs, labels = next(train_batches)

#%%
plots(imgs, titles=labels)

#%%
test_imgs, test_labels = next(test_batches)
plots(test_imgs, titles=test_labels)

#%%
test_labels = test_labels[:,0]
test_labels

#%%
#VGG19 IMPORT

vgg19_model = keras.applications.vgg19.VGG19()
vgg19_model.summary()

# REMOVE THE LAST LAYER WITH ITS 1000 OUTPUTS
vgg19_model.layers.pop()

model = Sequential()

for layer in vgg19_model.layers:
    model.add(layer)

for layer in model.layers:
    layer.trainable = False

model.add(Dense(len(classi), activation='softmax', name='predictions'))

model.compile(Adam(lr=.0001), loss='categorical_crossentropy', metrics=['accuracy'])

hist = model.fit_generator(train_batches, steps_per_epoch=stepsEpoche, validation_data=valid_batches, validation_steps=stepsValid, epochs=numEpoche, verbose=2)

test_imgs, test_labels = next(test_batches)
plots(test_imgs, titles=test_labels)
test_labels= test_labels[:,0]
test_labels
predictions = model.predict_generator(test_batches, steps=1, verbose=0)

print("Un po di statistiche: ")
acc = hist.history['acc']
#print("accuracy " + str(acc))
val_acc = hist.history['val_acc']
#print("val accuracy " + str(val_acc))
loss = hist.history['loss']
#print("loss " + str(loss))
val_loss = hist.history['val_loss']
#print("val loss " + str(val_loss))

stats = []
stats.append(acc)
stats.append(val_acc)
stats.append(loss)
stats.append(val_loss)
# save the statistics to a file
with open('model/stats.txt', 'wb') as fp:
    pickle.dump(stats, fp)
print("Statistics for this learning session saved to stats.txt")
epochs = range(1, len(acc) + 1)

"bo" is for "blue dot"

plt.plot(epochs, loss, 'bo', label='Training loss')

# "b" is for "solid blue line"

plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

plt.plot(epochs, acc, 'bo', label='Accuracy')

# "b" is for "solid blue line"

plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Accuracy and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.show()

#%%
# SERIALIZE THE MODEL TO A JSON FILE

# Save JSON config to disk

json_config = model.to_json()
with open('model_config.json', 'w') as json_file:
    json_file.write(json_config)

# Save weights to disk

model.save_weights('weights.h5')
model.save('model.h5')

print("Pesi del modello salvati sul disco come 'weights.h5' ")
print("Modello salvati sul disco come 'model.h5' ")
print("JSON del modello salvati sul disco come 'model_config.json' ")

#%%
from tensorflow.python.tools import freeze_graph

# All new operations will be in test mode from now on.

K.set_learning_phase(0)

# Serialize the model and get its weights, for quick re-building.

config = model.get_config()
weights = model.get_weights()

# Re-build a model where the learning phase is now hard-coded to 0.

new_model = Sequential.from_config(config)
new_model.set_weights(weights)

temp_dir = "graph"
checkpoint_prefix = os.path.join(temp_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
output_graph_name = "output_graph.pb"

# Temporarily save the graph to disk, without weights included.

saver = tf.train.Saver()
checkpoint_path = saver.save(K.get_session(), checkpoint_prefix, global_step=0, latest_filename=checkpoint_state_name)
tf.train.write_graph(K.get_session().graph, temp_dir, input_graph_name)

input_graph_path = os.path.join(temp_dir, input_graph_name)
input_saver_def_path = ""
input_binary = False
output_node_names = "predictions/Softmax" # model dependent
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(temp_dir, output_graph_name)
clear_devices = False
#%%
sess = K.get_session()
graph_def = sess.graph.as_graph_def()
for node in graph_def.node:
    print(node)
#%%

# Embed weights inside the graph and save to disk.

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, "")

@fearlessfara

I then use optimize_for_inference.py to optimize the graph, but when I finally try to load the net into OpenCV using this code:
cvNet = cv.dnn.readNetFromTensorflow('opt_model.pb')

I get this error:
Traceback (most recent call last):
File "C:/Users/chris/PycharmProjects/flask/opencv.py", line 9, in
cvOut = cvNet.forward()
cv2.error: OpenCV(4.0.0) C:\projects\opencv-python\opencv\modules\dnn\src\dnn.cpp:413: error: (-2:Unspecified error) Can't create layer "flatten/Shape" of type "Shape" in function 'cv::dnn::dnn4_v20180917::LayerData::getLayerInstance'

Process finished with exit code 1
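
For context, the optimize step looks roughly like this; a minimal sketch using the TF 1.x optimize_for_inference_lib API, where "input_1" is an assumed input node name (check it against the node listing above):

import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib
# Load the frozen graph produced by freeze_graph.
graph_def = tf.GraphDef()
with tf.gfile.GFile("graph/output_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
# Strip training-only nodes between the given input and output nodes.
opt_graph_def = optimize_for_inference_lib.optimize_for_inference(
    graph_def,
    ["input_1"],              # assumed input node name, model dependent
    ["predictions/Softmax"],  # output node name used above
    tf.float32.as_datatype_enum)
with tf.gfile.GFile("opt_model.pb", "wb") as f:
    f.write(opt_graph_def.SerializeToString())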

@nihalkenkre

Hello, I have exhausted the extent of my help on this matter. :|

@fearlessfara

@amadlover, thanks all the same. I hope someone will read this and offer an opinion/solution.
