
@allenyllee
Created October 6, 2017 06:08
A sample of TensorFlow: a simple convolutional neural network classifier.
"""Simple convolutional neural network classififer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

FLAGS = tf.flags.FLAGS


def get_params():
  """Model params."""
  return {
      "drop_rate": 0.5
  }

def model(features, labels, mode, params):
  """CNN classifier model."""
  images = features["image"]
  labels = labels["label"]

  tf.summary.image("images", images)

  # Dropout is only active during training. Note that params is expected to
  # support attribute access (an HParams-like object) and to provide
  # num_classes in addition to the values from get_params().
  drop_rate = params.drop_rate if mode == tf.estimator.ModeKeys.TRAIN else 0.0

  # Three conv + max-pool blocks with 32, 64 and 128 filters.
  features = images
  for i, filters in enumerate([32, 64, 128]):
    features = tf.layers.conv2d(
        features, filters=filters, kernel_size=3, padding="same",
        name="conv_%d" % (i + 1))
    features = tf.layers.max_pooling2d(
        inputs=features, pool_size=2, strides=2, padding="same",
        name="pool_%d" % (i + 1))

  # Flatten, then classify with two dense layers and dropout in between.
  features = tf.contrib.layers.flatten(features)
  features = tf.layers.dropout(features, drop_rate)
  features = tf.layers.dense(features, 512, name="dense_1")
  features = tf.layers.dropout(features, drop_rate)
  logits = tf.layers.dense(features, params.num_classes, activation=None,
                           name="dense_2")

  predictions = tf.argmax(logits, axis=1)
  loss = tf.losses.sparse_softmax_cross_entropy(
      labels=labels, logits=logits)

  return {"predictions": predictions}, loss

def eval_metrics(unused_params):
  """Eval metrics."""
  return {
      "accuracy": tf.contrib.learn.MetricSpec(tf.metrics.accuracy)
  }
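
The model function above returns a (predictions, loss) pair rather than a tf.estimator.EstimatorSpec, so it needs a small adapter before it can be handed to tf.estimator.Estimator. The following is a minimal sketch of one way to do that under TF 1.x; the wrapper name estimator_model_fn, the Adam optimizer, the HParams object, and num_classes=10 are assumptions made for illustration, not part of the gist.

def estimator_model_fn(features, labels, mode, params):
  """Adapts the (predictions, loss) convention above to an EstimatorSpec."""
  # PREDICT mode is omitted because model() reads labels unconditionally.
  predictions, loss = model(features, labels, mode, params)
  train_op = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(
        loss, global_step=tf.train.get_global_step())
  eval_metric_ops = {
      "accuracy": tf.metrics.accuracy(labels["label"],
                                      predictions["predictions"]),
  }
  return tf.estimator.EstimatorSpec(
      mode, predictions=predictions, loss=loss,
      train_op=train_op, eval_metric_ops=eval_metric_ops)


hparams = tf.contrib.training.HParams(num_classes=10, **get_params())
classifier = tf.estimator.Estimator(model_fn=estimator_model_fn, params=hparams)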