- Train your model with Keras in Python.
- Call serialize() to write the graph, with its trained parameters frozen into constants, to a file.
- Call load-graph to load that file through the TensorFlow Java API.
Last active
August 31, 2017 20:18
How to load Keras graphs from Clojure
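(ns tf
  (:require [byte-streams :as bs])
  (:import [org.tensorflow Graph Session]))

(defn load-graph
  "Imports the serialized GraphDef read from `inf` into a new Graph.
  `inf` is anything byte-streams can convert to a byte array,
  such as an InputStream or a File."
  [inf]
  (let [^Graph g (Graph.)]
    (.importGraphDef g (bs/to-byte-array inf))
    g))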
import tensorflow as tf
import keras.backend as K
from keras.models import Model
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos


def serialize(model_getter, output_path):
    """
    model_getter: function that returns a keras Model
    output_path: filepath to save the tensorflow pb model to
    """
    # Make this session the one Keras uses before the model is built.
    session = tf.Session()
    K.set_session(session)
    K.set_learning_phase(0)  # inference mode
    model = model_getter()
    config = model.get_config()
    weights = model.get_weights()
    input_graph = session.graph_def
    # Replace variables with constants. The output node names are the
    # output tensor names without their ":<index>" suffix.
    output_graph = graph_util.convert_variables_to_constants(
        session, input_graph, [v.name.split(':')[0] for v in model.outputs])
    # SerializeToString() returns bytes, so open the file in binary mode.
    with open(output_path, 'wb') as outf:
        outf.write(output_graph.SerializeToString())
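The list comprehension passed to `convert_variables_to_constants` turns tensor names into op (node) names by dropping the `:<output index>` suffix. Here is a minimal sketch of just that step in plain Python. The helper `op_name` is hypothetical and the tensor names are made up for illustration; a real model's names depend on its layers.

```python
def op_name(tensor_name):
    """Return the op (node) name for a TensorFlow tensor name.

    Tensor names have the form "<op name>:<output index>", while
    convert_variables_to_constants expects plain op names.
    """
    return tensor_name.split(':')[0]


# Illustrative tensor names only, not taken from a real model.
example_outputs = ['dense_1/Softmax:0', 'aux_out/Sigmoid:0']
output_node_names = [op_name(name) for name in example_outputs]
print(output_node_names)
```

These op names are also the ones you would use to fetch outputs through the TensorFlow Java API once the graph has been loaded with `load-graph`.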