Skip to content

Instantly share code, notes, and snippets.

@prabindh
Created May 10, 2019 05:50
Show Gist options
  • Save prabindh/02ec8f083f376c294a3223ec686120cb to your computer and use it in GitHub Desktop.
Save prabindh/02ec8f083f376c294a3223ec686120cb to your computer and use it in GitHub Desktop.
# https://stackoverflow.com/questions/56069411
# The TF code
#
# NOTE(review): relies on `tf` (TensorFlow 1.x graph mode), `np`, and the
# arrays train_x / train_y / test_x / test_y being defined elsewhere.
score_inputs = tf.placeholder(np.float32, shape=(None, 100))
# The original `shape=(None)` is just `None` (not a 1-tuple), i.e. any shape
# is accepted. Targets are flattened below, so (batch,) and (batch, 1) feeds
# both work.
targets = tf.placeholder(np.float32, shape=None, name="targets")
l2 = tf.contrib.layers.l2_regularizer(0.01)
first_layer = tf.layers.dense(score_inputs, 100, activation=tf.nn.relu, kernel_regularizer=l2)
outputs = tf.layers.dense(first_layer, 1, activation=None, kernel_regularizer=l2)

# `outputs` is (batch, 1). Subtracting a (batch,) target from it broadcasts
# to a (batch, batch) matrix and silently yields the wrong loss/metrics.
# Compare flat vectors instead. This is presumably the source of the
# TF-vs-Keras discrepancy, since Keras aligns target and output shapes
# before applying 'mse'.
flat_outputs = tf.reshape(outputs, [-1])
flat_targets = tf.reshape(targets, [-1])
errors = flat_targets - flat_outputs

optimizer = tf.train.AdamOptimizer(0.001)
l2_loss = tf.losses.get_regularization_loss()
# Training objective: mean squared error plus the L2 kernel penalties.
loss = tf.reduce_mean(tf.square(errors)) + l2_loss
rmse = tf.sqrt(tf.reduce_mean(tf.square(errors)))
mae = tf.reduce_mean(tf.abs(errors))
training_op = optimizer.minimize(loss)

batch_size = 32
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(10):
        avg_train_error = []
        # Step by batch_size so the trailing partial batch is trained on too
        # (`len(train_x) // batch_size` dropped it, unlike Keras' fit()).
        # Batches are taken in order, without shuffling.
        for start in range(0, len(train_x), batch_size):
            batch_x = train_x[start:start + batch_size]
            batch_y = train_y[start:start + batch_size]
            _, train_loss = sess.run([training_op, loss], {score_inputs: batch_x, targets: batch_y})
            avg_train_error.append(train_loss)
        print("epoch %d: mean train loss %.6f" % (epoch, np.mean(avg_train_error)))
    # Evaluate once after training, while the session is still open.
    feed = {score_inputs: test_x, targets: test_y}
    test_loss, test_mae, test_rmse, test_outputs = sess.run([loss, mae, rmse, outputs], feed)
# The keras code
#
# Mirrors the TF graph: 100 inputs -> Dense(100, relu) -> Dense(1, linear),
# L2(0.01) on both kernels, Adam(0.001), MSE loss, batch size 32, 10 epochs.
# NOTE(review): relies on Input, Dense, Model, regularizers, keras and the
# train/test arrays being defined elsewhere.
inputs = Input(shape=(100,))

hidden_layer = Dense(100, activation="relu", kernel_regularizer=regularizers.l2(0.01))
hidden = hidden_layer(inputs)

output_layer = Dense(1, activation=None, kernel_regularizer=regularizers.l2(0.01))
outputs = output_layer(hidden)

model = Model(inputs=inputs, outputs=outputs)
adam = keras.optimizers.Adam(lr=0.001)
model.compile(optimizer=adam, loss='mse', metrics=['mae'])

# shuffle=False keeps the batch order the same as the sequential TF loop.
model.fit(train_x, train_y, batch_size=32, epochs=10, shuffle=False)
keras_pred = model.predict(test_x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment