TensorFlow Arrow Blog Part 5 - Model Definition
Gist BryanCutler/68f14c6a9f1bcde2a22afaef55db0620, last active August 5, 2019 17:37
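import tensorflow as tf


def model_fit(ds):
    """Create and fit a Keras logistic regression model.

    `ds` is expected to yield batches of (features, labels), with
    2 features per example to match input_shape=(2,).
    """
    # Build the Keras model: a single dense unit with a sigmoid
    # activation maps the 2 input features to a value in [0, 1]
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=(2,),
                                    activation='sigmoid'))
    model.compile(optimizer='sgd', loss='mean_squared_error',
                  metrics=['accuracy'])
    # Fit the model on the given dataset; Keras ignores `shuffle` when
    # the input is a tf.data.Dataset, so any shuffling has to be done
    # on the dataset itself
    model.fit(ds, epochs=5, shuffle=False)
    return model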
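To make the compiled model more concrete, here is a minimal pure-Python sketch of what `Dense(1, activation='sigmoid')` trained with mean squared error and plain SGD computes for 2-feature inputs. It is an illustration under stated assumptions, not how Keras implements training: the gradients are derived by hand instead of by autodiff, the helper names (`predict`, `mse_loss`, `sgd_step`, `accuracy`) and the toy batch are hypothetical, and the weights start at zero so the demo is deterministic.

```python
import math


def sigmoid(z):
    """Logistic function, the same nonlinearity as activation='sigmoid'."""
    return 1.0 / (1.0 + math.exp(-z))


def predict(weights, bias, x):
    """Forward pass of a single sigmoid unit on one 2-feature example."""
    z = weights[0] * x[0] + weights[1] * x[1] + bias
    return sigmoid(z)


def mse_loss(weights, bias, batch):
    """Mean squared error over a batch of (features, label) pairs."""
    return sum((predict(weights, bias, x) - y) ** 2 for x, y in batch) / len(batch)


def sgd_step(weights, bias, batch, learning_rate=0.01):
    """One plain SGD update on the batch-mean MSE loss.

    For one example with p = sigmoid(z), the loss (p - y)**2 gives
      dL/dz = 2 * (p - y) * p * (1 - p)
    and then dL/dw_i = dL/dz * x_i and dL/db = dL/dz.
    The default learning rate of 0.01 mirrors the Keras SGD default.
    """
    grad_w = [0.0, 0.0]
    grad_b = 0.0
    for x, y in batch:
        p = predict(weights, bias, x)
        dz = 2.0 * (p - y) * p * (1.0 - p)
        grad_w[0] += dz * x[0]
        grad_w[1] += dz * x[1]
        grad_b += dz
    n = len(batch)
    new_weights = [w - learning_rate * g / n for w, g in zip(weights, grad_w)]
    new_bias = bias - learning_rate * grad_b / n
    return new_weights, new_bias


def accuracy(weights, bias, batch):
    """Fraction of examples whose prediction, thresholded at 0.5, matches the label."""
    correct = sum(int((predict(weights, bias, x) > 0.5) == bool(y)) for x, y in batch)
    return correct / len(batch)


# Hypothetical toy batch standing in for one batch of the real dataset
toy_batch = [((1.0, 2.0), 1.0), ((-1.0, -2.0), 0.0),
             ((2.0, 1.0), 1.0), ((-2.0, -1.0), 0.0)]

# Start from zeros for a deterministic demo (Keras itself initializes the
# kernel randomly, with Glorot uniform by default)
weights, bias = [0.0, 0.0], 0.0
initial_loss = mse_loss(weights, bias, toy_batch)  # every prediction starts at 0.5

# A larger learning rate than the default, so the toy example moves quickly
for _ in range(100):
    weights, bias = sgd_step(weights, bias, toy_batch, learning_rate=1.0)
final_loss = mse_loss(weights, bias, toy_batch)
```

Logistic regression is more commonly trained with binary cross-entropy (`loss='binary_crossentropy'` in Keras); the sketch follows the gist's choice of mean squared error so the two stay comparable.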