securetorobert/ddfd1e0b494bc94c9dcf86f858bfca04
Created July 15, 2018 10:37
A DNNClassifier for the Iris dataset
# Requires a TensorFlow version that still ships the tf.estimator API
# (deprecated in TF 2.x).
import tensorflow as tf

PATH = "/tmp/iris_model"  # hypothetical checkpoint directory; set your own

feature_names = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth']

# Create the feature_columns, which specify the input to our model.
# All our input features are numeric, so use numeric_column for each one.
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

# Create a deep neural network classifier
# using the pre-made DNNClassifier estimator.
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,  # The input features to our model
    hidden_units=[10, 10],            # Two hidden layers, each with 10 neurons
    n_classes=3,                      # One class per iris species
    model_dir=PATH)                   # Path where checkpoints etc. are stored
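To make the configuration above concrete, here is a minimal pure-Python sketch of the forward computation such a network performs: 4 numeric inputs, two hidden layers of 10 units, and 3 output logits turned into class probabilities. This is not TensorFlow's implementation. It assumes DNNClassifier's default ReLU activation on the hidden layers and a softmax over the logits; the weights are random (untrained), and the helper names (`build_network`, `predict_proba`, `count_params`) are hypothetical.

```python
import math
import random

FEATURE_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
HIDDEN_UNITS = [10, 10]  # mirrors hidden_units=[10, 10]
N_CLASSES = 3            # mirrors n_classes=3


def init_layer(n_in, n_out, rng):
    """Random weights (n_in x n_out) and zero biases for one dense layer."""
    weights = [[rng.gauss(0.0, 0.1) for _ in range(n_out)] for _ in range(n_in)]
    biases = [0.0] * n_out
    return weights, biases


def dense(x, weights, biases, relu):
    """Compute x @ weights + biases, optionally followed by ReLU."""
    out = [sum(x[i] * weights[i][j] for i in range(len(x))) + biases[j]
           for j in range(len(biases))]
    return [max(0.0, v) for v in out] if relu else out


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_network(seed=0):
    """Hypothetical helper: layer sizes [4, 10, 10, 3], randomly initialised."""
    rng = random.Random(seed)
    sizes = [len(FEATURE_NAMES)] + HIDDEN_UNITS + [N_CLASSES]
    return [init_layer(a, b, rng) for a, b in zip(sizes[:-1], sizes[1:])]


def predict_proba(layers, features):
    """features: dict mapping each feature name to a float value."""
    x = [features[name] for name in FEATURE_NAMES]
    for k, (w, b) in enumerate(layers):
        is_output = k == len(layers) - 1
        x = dense(x, w, b, relu=not is_output)  # ReLU on hidden layers only
    return softmax(x)  # logits -> class probabilities


def count_params(layers):
    """Total weights plus biases across all layers."""
    return sum(len(w) * len(w[0]) + len(b) for w, b in layers)


layers = build_network()
probs = predict_proba(layers, {'SepalLength': 5.1, 'SepalWidth': 3.5,
                               'PetalLength': 1.4, 'PetalWidth': 0.2})
print(count_params(layers))  # 193
print(len(probs), round(sum(probs), 6))  # 3 1.0
```

The parameter count follows directly from the layer sizes: (4·10 + 10) + (10·10 + 10) + (10·3 + 3) = 50 + 110 + 33 = 193. Note that the features are passed as a dict keyed by name, which mirrors how each `numeric_column(k)` is looked up by its key `k`.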