Skip to content

Instantly share code, notes, and snippets.

@Swalloow
Last active April 22, 2018 08:59
Show Gist options
  • Save Swalloow/c90dea8bce9ceee147851656f48fac9f to your computer and use it in GitHub Desktop.
Save Swalloow/c90dea8bce9ceee147851656f48fac9f to your computer and use it in GitHub Desktop.
class Model:
    """Single-layer softmax classifier expressed as a TensorFlow graph.

    All graph nodes are created once in the constructor; the prediction,
    training and error operations are then exposed as read-only
    properties so callers can pass them to a session.
    """

    def __init__(self, data, target):
        """Build the prediction, optimisation and error operations.

        Args:
            data: Input tensor. Its second dimension is read as the number
                of input features and must be statically known, because
                it is converted with ``int()``.
            target: Label tensor. Its second dimension is read as the
                number of outputs (same static-shape requirement).
                Presumably one-hot rows, since the arg-max over axis 1 is
                compared against the prediction -- confirm with callers.
        """
        data_size = int(data.get_shape()[1])
        target_size = int(target.get_shape()[1])

        # One dense layer: logits = data @ weight + bias.
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(data, weight) + bias
        self._prediction = tf.nn.softmax(incoming)

        # Cross-entropy is -sum(target * log(prediction)). The labels must
        # be multiplied element-wise with the log-probabilities; passing
        # the log tensor as a second positional argument would make
        # reduce_sum treat it as the reduction axis instead.
        cross_entropy = -tf.reduce_sum(target * tf.log(self._prediction))
        self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(
            cross_entropy)

        # Fraction of rows whose predicted class differs from the label.
        mistakes = tf.not_equal(
            tf.argmax(target, 1), tf.argmax(self._prediction, 1))
        self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @property
    def prediction(self):
        """Softmax output of the layer."""
        return self._prediction

    @property
    def optimize(self):
        """Training operation that minimises the cross-entropy."""
        return self._optimize

    @property
    def error(self):
        """Mean misclassification rate as a float32 tensor."""
        return self._error
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment