A DNNClassifier for the iris dataset
import tensorflow as tf

# Names of the four numeric input features in the iris dataset.
feature_names = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth']
# Create the feature_columns, which specifies the input to our model.
# All our input features are numeric, so use numeric_column for each one.
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]
# Create a deep neural network classifier.
# Use the DNNClassifier pre-made estimator
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,  # The input features to our model
    hidden_units=[10, 10],  # Two hidden layers, each with 10 neurons
    n_classes=3,  # One class per iris species
    model_dir=PATH)  # Directory for checkpoints etc.; PATH must be defined beforehand
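To give a feel for how small this network is, here is a rough sketch that counts its trainable parameters. It assumes every layer is a plain fully connected layer with a bias term. The helper `dense_param_count` is hypothetical and written only for this illustration; it is not part of TensorFlow.

```python
def dense_param_count(layer_sizes):
    """Count weights and biases in a stack of fully connected layers.

    layer_sizes lists the width of each layer, input first, output last.
    Assumes each layer has one bias per output unit.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix plus bias vector
    return total


n_features = 4            # SepalLength, SepalWidth, PetalLength, PetalWidth
hidden_units = [10, 10]   # same value passed to the classifier above
n_classes = 3

layer_sizes = [n_features] + hidden_units + [n_classes]
total_params = dense_param_count(layer_sizes)
print(total_params)  # 50 + 110 + 33 = 193
```

Under these assumptions, the model has fewer than two hundred trainable parameters, which is in keeping with a dataset as small as iris.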