Last active: October 10, 2021 12:45
-
-
Save asimshankar/fb1f42c3bd91e1bb041f34a848e59fe1 to your computer and use it in GitHub Desktop.
TensorFlow: Saving and restoring variables in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Exports a small TensorFlow 1.x graph plus an initial checkpoint for the Go
# program, and prints Go save/restore helpers bound to the names of this
# graph's tf.train.Saver operations. Written to run under Python 2 and 3.
import tensorflow as tf


def _op_name(tensor_name):
    """Return the operation name for a tensor name like 'save/Const:0'.

    Go's graph.Operation() takes an operation name, while SaverDef stores
    tensor names, so drop the ':<output index>' suffix.
    """
    return tensor_name.rsplit(':', 1)[0]


# Construct the graph
x = tf.Variable(1, name='x')
y = tf.Variable(2, name='y')
# Bound to sum_op to avoid shadowing the builtin sum(); the TF op name stays
# 'sum' because the Go program looks the operation up by that name.
sum_op = tf.assign_add(x, y, name='sum')

# Add operations to save and restore checkpoints
saver = tf.train.Saver()

# Save the graph. SerializeToString() returns bytes, so the file must be
# opened in binary mode (text mode raises TypeError on Python 3).
with open('/tmp/graph.pb', 'wb') as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())

# Print out Go code snippet to save/restore
# Perhaps it may make sense for tf.Session to return a pointer
# to the tf.Graph it operates on instead of having to pass both
# the graph and session consistently.
sd = saver.saver_def
# The % formatting happens inside print(...) so this works both with the
# Python 2 print statement and the Python 3 print function.
print('''
// save saves the current value of variables in graph/sess in files with the
// given prefix and returns the string to provide to restore.
func save(graph *tf.Graph, sess *tf.Session, prefix string) (string, error) {
t, err := tf.NewTensor(prefix)
if err != nil {
return "", err
}
o := graph.Operation("%s").Output(0)
ret, err := sess.Run(map[tf.Output]*tf.Tensor{o:t}, []tf.Output{graph.Operation("%s").Output(0)}, nil)
if err != nil {
return "", err
}
return ret[0].Value().(string), nil
}
// restore restores the value of variables previously saved using save.
func restore(graph *tf.Graph, sess *tf.Session, path string) error {
t, err := tf.NewTensor(path)
if err != nil {
return err
}
o := graph.Operation("%s").Output(0)
_, err = sess.Run(map[tf.Output]*tf.Tensor{o:t}, nil, []*tf.Operation{graph.Operation("%s")})
return err
}
''' % (_op_name(sd.filename_tensor_name), _op_name(sd.save_tensor_name),
       _op_name(sd.filename_tensor_name), sd.restore_op_name))

# For fun, save the checkpoint where x=3
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(sum_op)
    print("Saved to: " + saver.save(sess, "/tmp/ckpt1"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"io/ioutil" | |
"log" | |
tf "github.com/tensorflow/tensorflow/tensorflow/go" | |
) | |
func main() { | |
gdef, err := ioutil.ReadFile("/tmp/graph.pb") | |
if err != nil { | |
log.Fatal(err) | |
} | |
graph := tf.NewGraph() | |
if err := graph.Import(gdef, ""); err != nil { | |
log.Fatal(err) | |
} | |
sess, err := tf.NewSession(graph, nil) | |
if err != nil { | |
log.Fatal(err) | |
} | |
defer sess.Close() | |
// Restore an existing checkpoint | |
if err := restore(graph, sess, "/tmp/ckpt1"); err != nil { | |
log.Fatal(err) | |
} | |
// Run an update and save a new checkpoint. | |
if _, err := sess.Run(nil, nil, []*tf.Operation{graph.Operation("sum")}); err != nil { | |
log.Fatal(err) | |
} | |
path, err := save(graph, sess, "/tmp/ckpt2") | |
if err != nil { | |
log.Fatal(err) | |
} | |
fmt.Println("Saved checkpoint to", path) | |
} | |
// Code below generated by the python script above | |
// save saves the current value of variables in graph/sess in files with the | |
// given prefix and returns the string to provide to restore. | |
func save(graph *tf.Graph, sess *tf.Session, prefix string) (string, error) { | |
t, err := tf.NewTensor(prefix) | |
if err != nil { | |
return "", err | |
} | |
o := graph.Operation("save/Const").Output(0) | |
ret, err := sess.Run(map[tf.Output]*tf.Tensor{o: t}, []tf.Output{graph.Operation("save/control_dependency").Output(0)}, nil) | |
if err != nil { | |
return "", err | |
} | |
return ret[0].Value().(string), nil | |
} | |
// restore restores the value of variables previously saved using save. | |
func restore(graph *tf.Graph, sess *tf.Session, path string) error { | |
t, err := tf.NewTensor(path) | |
if err != nil { | |
return err | |
} | |
o := graph.Operation("save/Const").Output(0) | |
_, err = sess.Run(map[tf.Output]*tf.Tensor{o: t}, nil, []*tf.Operation{graph.Operation("save/restore_all")}) | |
return err | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.