@bobpoekert
Last active August 31, 2017 20:18
How to load Keras graphs from Clojure
  1. Train your model in Keras (Python).
  2. Call serialize() to freeze the graph (trained weights converted to constants) and write it to a .pb file.
  3. Call load-graph to import that file with the TensorFlow Java API.
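
(ns tf
  (:require [byte-streams :as bs])
  (:import [org.tensorflow Graph Session]))

(defn load-graph
  "Import a serialized (frozen) GraphDef into a new TensorFlow Graph.
  inf can be anything byte-streams can turn into a byte array,
  e.g. a java.io.File or an InputStream."
  [inf]
  (let [^Graph g (Graph.)]
    (.importGraphDef g (bs/to-byte-array inf))
    g))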
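
load-graph only imports the graph. To run it, you still look up the input and output operations by name (for example with Graph.operation, or by passing names to a Session runner's feed/fetch in the Java API). Keras reports tensor names such as dense_1/Softmax:0, where the suffix is the output index, not part of the node name. The following is a minimal Python sketch for the training side that separates the two so the names can be printed and copied into the Clojure code. split_tensor_name is a hypothetical helper, not part of Keras or TensorFlow.

```python
def split_tensor_name(tensor_name):
    """Split a TensorFlow tensor name such as 'dense_1/Softmax:0' into
    (node_name, output_index). A name with no ':' is treated as output 0."""
    # rpartition splits on the last ':' so node names containing '/' are kept whole
    node_name, sep, index = tensor_name.rpartition(':')
    if not sep:
        return tensor_name, 0
    return node_name, int(index)
```

For example, `[split_tensor_name(t.name)[0] for t in model.inputs]` gives the node names to feed. This mirrors the `split(':')[0]` that serialize() applies to model.outputs.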

import tensorflow as tf
import keras.backend as K
from tensorflow.python.framework import graph_util

def serialize(model_getter, output_path):
    """
    model_getter: function that returns a keras Model
    output_path: filepath to save the tensorflow pb model to

    Returns the output node names, which are needed to fetch results
    from the loaded graph on the JVM side.
    """
    session = tf.Session()
    K.set_session(session)
    # Fix the learning phase to inference (no dropout, batch norm uses its
    # moving averages) before the model is built.
    K.set_learning_phase(0)
    model = model_getter()
    # Tensor names look like "dense_1/Softmax:0"; the node name is the part
    # before the colon.
    output_names = [v.name.split(':')[0] for v in model.outputs]
    input_graph = session.graph_def
    # Replace every variable with a constant holding its current value, so
    # the trained weights are embedded in the GraphDef itself.
    output_graph = graph_util.convert_variables_to_constants(
        session, input_graph, output_names)
    # SerializeToString() returns bytes, so open the file in binary mode.
    with open(output_path, 'wb') as outf:
        outf.write(output_graph.SerializeToString())
    return output_names