TensorFlow training script that does not generate graph.pbtxt
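The script below reproduces the issue: a spatial-transformer MNIST classifier trained with a hand-written tf.Session loop rather than tf.estimator.Estimator. Unlike an Estimator, which writes a graph.pbtxt into its model directory when training starts, nothing in this loop serializes the graph definition, so the run directory only ever contains the TensorBoard event files.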
import numpy as np
import tensorflow as tf

from procnet.utils import sensible_dir
from experiments.SpatialTransformerNetwork.model import spatial_transformer
from experiments.SpatialTransformerNetwork import dataset
def build_model_layers(inputs, nclasses, is_training):
    """Spatial transformer followed by a small convnet; returns the logits."""
    initializer = tf.contrib.layers.xavier_initializer()

    inputs = tf.layers.Input((28, 28, 1), tensor=inputs)
    net = spatial_transformer(inputs, grid_dims=(1, 2))

    with tf.name_scope("conv1"):
        net = tf.layers.conv2d(
            net, 10, (5, 5),
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)
        net = tf.layers.max_pooling2d(net, 2, 2)
        net = tf.nn.relu(net)

    with tf.name_scope("conv2"):
        net = tf.layers.conv2d(
            net, 20, (5, 5),
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)
        if is_training:
            net = tf.nn.dropout(net, 0.9)  # keep_prob=0.9, training graph only
        net = tf.layers.max_pooling2d(net, 2, 2)
        net = tf.nn.relu(net)

    net = tf.layers.flatten(net)

    with tf.name_scope("dense1"):
        net = tf.layers.dense(
            net, 50,
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)
        if is_training:
            net = tf.nn.dropout(net, 0.9)
        net = tf.nn.relu(net)

    with tf.name_scope("dense2"):
        logits = tf.layers.dense(
            net, nclasses,
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)

    return logits
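# Weight sharing: main() builds this model twice inside the "inference"
# variable scope, first with reuse=False (training graph, dropout enabled)
# and then with reuse=True so the evaluation graph binds to the variables
# created by the first call.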
def train_input_fn(batch_size):
    # First 50000 samples are used for training.
    features, labels = dataset.train_x[:50000], dataset.train_y[:50000]
    features = features.reshape((-1, 28, 28, 1)).astype(np.float32)
    features = (features / 255 - 0.1307) / 0.3081  # standard MNIST mean/std
    labels = labels.reshape((-1,)).astype(np.int32)

    pairs = tf.data.Dataset.from_tensor_slices((features, labels))
    pairs = pairs.shuffle(len(labels)).batch(batch_size).repeat()
    return pairs.make_one_shot_iterator().get_next()
def eval_input_fn(batch_size):
    # Remaining samples are held out for evaluation.
    features, labels = dataset.train_x[50000:], dataset.train_y[50000:]
    features = features.reshape((-1, 28, 28, 1)).astype(np.float32)
    features = (features / 255 - 0.1307) / 0.3081
    labels = labels.reshape((-1,)).astype(np.int32)

    pairs = tf.data.Dataset.from_tensor_slices((features, labels))
    pairs = pairs.batch(batch_size).repeat()  # no shuffling for evaluation
    return pairs.make_one_shot_iterator().get_next()
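# Both pipelines end with .repeat() behind a one-shot iterator, so the
# get_next() ops never raise tf.errors.OutOfRangeError during the
# fixed-length training loop in main().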
def main():
    batch_size = 64

    with tf.Session() as sess:
        train_inputs, train_labels = train_input_fn(batch_size)
        val_inputs, val_labels = eval_input_fn(batch_size)

        # Training graph: creates the model variables.
        with tf.variable_scope("inference", reuse=False):
            train_logits = build_model_layers(train_inputs, 10, True)
            train_loss = tf.losses.sparse_softmax_cross_entropy(
                train_labels, train_logits)
            tf.summary.scalar('loss', train_loss)
            optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
            # get_global_step() resolves to None here (no global step variable
            # was ever created); the loop counter below indexes the summaries.
            train_op = optimizer.minimize(
                train_loss, global_step=tf.train.get_global_step())

        # Evaluation graph: reuses the same variables, without dropout.
        with tf.variable_scope("inference", reuse=True):
            val_logits = build_model_layers(val_inputs, 10, False)
            val_loss = tf.losses.sparse_softmax_cross_entropy(
                val_labels, val_logits)
            tf.summary.scalar('loss', val_loss)
            predicted_classes = tf.argmax(val_logits, 1)
            accuracy, accuracy_update = tf.metrics.accuracy(
                val_labels, predicted_classes)

        model_dir = sensible_dir(
            "experiments/SpatialTransformerNetwork/checkpoints", "run_")
        # sess.graph is embedded in the event file so TensorBoard can display
        # the graph, but no graph.pbtxt is ever written.
        train_writer = tf.summary.FileWriter(model_dir + "/train", sess.graph)
        eval_writer = tf.summary.FileWriter(model_dir + "/eval")
        # Merges both 'loss' scalars, so fetching it also pulls an eval batch.
        merged = tf.summary.merge_all()
        train_writer.flush()

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # for tf.metrics.accuracy

        for step in range(50000 // batch_size * 1):  # one epoch over 50000 samples
            summary, _ = sess.run([merged, train_op])
            train_writer.add_summary(summary, step)

            if (step + 1) % 25 == 0:
                # accuracy_update advances and returns the running accuracy.
                summary, acc = sess.run([merged, accuracy_update])
                eval_writer.add_summary(summary, step)

        train_writer.close()
        eval_writer.close()


if __name__ == "__main__":
    main()
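For reference, the graph.pbtxt that tf.estimator.Estimator leaves in its model directory can be produced by hand with tf.train.write_graph (roughly what the Estimator's checkpoint hook does internally). A minimal sketch, assuming it is added at the end of main() while sess and model_dir are still in scope:

# Serialize the default graph as a human-readable text protobuf,
# mirroring the graph.pbtxt an Estimator would leave in model_dir.
tf.train.write_graph(sess.graph_def, model_dir, "graph.pbtxt", as_text=True)

Even without this file, TensorBoard's graph tab still works for this script, because train_writer is constructed with sess.graph and therefore embeds the graph definition in its event file.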