Skip to content

Instantly share code, notes, and snippets.

@joshterrill
Last active January 20, 2018 19:27
Show Gist options
  • Save joshterrill/24915078ecd4d06db0a17cc753ba9303 to your computer and use it in GitHub Desktop.
Save joshterrill/24915078ecd4d06db0a17cc753ba9303 to your computer and use it in GitHub Desktop.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example code for TensorFlow Wide & Deep Tutorial using TF.Learn API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tempfile
import pandas as pd
from six.moves import urllib
import tensorflow as tf
# Column names, in order, handed to pandas.read_csv(names=...) when the census
# CSV files are loaded.
CSV_COLUMNS = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
    "native_country",
    "income_bracket",
]
# Categorical base columns whose full vocabulary is spelled out explicitly.
gender = tf.feature_column.categorical_column_with_vocabulary_list(
    key="gender",
    vocabulary_list=["Female", "Male"],
)
education = tf.feature_column.categorical_column_with_vocabulary_list(
    key="education",
    vocabulary_list=[
        "Bachelors", "HS-grad", "11th", "Masters", "9th",
        "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
        "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
        "Preschool", "12th",
    ],
)
marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
    key="marital_status",
    vocabulary_list=[
        "Married-civ-spouse", "Divorced", "Married-spouse-absent",
        "Never-married", "Separated", "Married-AF-spouse", "Widowed",
    ],
)
relationship = tf.feature_column.categorical_column_with_vocabulary_list(
    key="relationship",
    vocabulary_list=[
        "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
        "Other-relative",
    ],
)
workclass = tf.feature_column.categorical_column_with_vocabulary_list(
    key="workclass",
    vocabulary_list=[
        "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
        "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked",
    ],
)

# Categorical columns mapped to 1000 hash buckets instead of a listed
# vocabulary (kept as an example of hashing).
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    key="occupation",
    hash_bucket_size=1000,
)
native_country = tf.feature_column.categorical_column_with_hash_bucket(
    key="native_country",
    hash_bucket_size=1000,
)
# Numeric (continuous) base columns, one per raw feature.
age = tf.feature_column.numeric_column(key="age")
education_num = tf.feature_column.numeric_column(key="education_num")
capital_gain = tf.feature_column.numeric_column(key="capital_gain")
capital_loss = tf.feature_column.numeric_column(key="capital_loss")
hours_per_week = tf.feature_column.numeric_column(key="hours_per_week")

# Derived column: age split into ranges at the listed boundaries.
age_buckets = tf.feature_column.bucketized_column(
    source_column=age,
    boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65],
)
# Columns for the wide (linear) part of the model: the plain categorical and
# bucketized columns ...
base_columns = [
    gender,
    education,
    marital_status,
    relationship,
    workclass,
    occupation,
    native_country,
    age_buckets,
]

# ... and feature crosses, each hashed into 1000 buckets.
crossed_columns = [
    tf.feature_column.crossed_column(
        keys=["education", "occupation"],
        hash_bucket_size=1000,
    ),
    tf.feature_column.crossed_column(
        keys=[age_buckets, "education", "occupation"],
        hash_bucket_size=1000,
    ),
    tf.feature_column.crossed_column(
        keys=["native_country", "occupation"],
        hash_bucket_size=1000,
    ),
]
# Columns for the deep (DNN) part of the model. The three groups are
# concatenated in the order: indicator columns, embeddings, numeric columns.
deep_columns = (
    # Vocabulary-based categorical columns wrapped as indicator columns.
    [
        tf.feature_column.indicator_column(categorical_column=workclass),
        tf.feature_column.indicator_column(categorical_column=education),
        tf.feature_column.indicator_column(categorical_column=gender),
        tf.feature_column.indicator_column(categorical_column=relationship),
    ]
    # Hashed categorical columns embedded into 8 dimensions (kept as an
    # example of embedding).
    + [
        tf.feature_column.embedding_column(
            categorical_column=native_country, dimension=8),
        tf.feature_column.embedding_column(
            categorical_column=occupation, dimension=8),
    ]
    # Numeric columns are used as they are.
    + [
        age,
        education_num,
        capital_gain,
        capital_loss,
        hours_per_week,
    ]
)
def _download_to_temp(url):
    """Download `url` into a new temporary file and return that file's path."""
    # Only the reserved path is needed, so the handle is closed before the
    # download starts: nothing stays open if the download raises, and the
    # path can be reopened for writing on platforms that do not allow a file
    # to be opened twice.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        file_name = temp_file.name
    urllib.request.urlretrieve(url, file_name)
    return file_name


def maybe_download(
        train_data,
        test_data,
        train_url="https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",  # pylint: disable=line-too-long
        test_url="https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"):  # pylint: disable=line-too-long
    """Return train and test file names, downloading whichever is missing.

    Args:
        train_data: Path to an existing training file. If falsy, the file is
            downloaded from `train_url` into a temporary file.
        test_data: Path to an existing test file. If falsy, the file is
            downloaded from `test_url` into a temporary file.
        train_url: URL used when `train_data` is not given.
        test_url: URL used when `test_data` is not given.

    Returns:
        A `(train_file_name, test_file_name)` tuple of paths.

    Downloaded temporary files are created with `delete=False`, so they are
    not removed automatically.
    """
    if train_data:
        train_file_name = train_data
    else:
        train_file_name = _download_to_temp(train_url)
        print("Training data is downloaded to %s" % train_file_name)

    if test_data:
        test_file_name = test_data
    else:
        test_file_name = _download_to_temp(test_url)
        print("Test data is downloaded to %s" % test_file_name)

    return train_file_name, test_file_name
def build_estimator(model_dir, model_type):
    """Create the estimator selected by `model_type`.

    Args:
        model_dir: Directory passed to the estimator as its `model_dir`.
        model_type: "wide" selects a LinearClassifier over the base and
            crossed columns; "deep" selects a DNNClassifier over the deep
            columns; any other value selects the combined
            DNNLinearCombinedClassifier.

    Returns:
        The constructed estimator.
    """
    if model_type == "wide":
        return tf.estimator.LinearClassifier(
            model_dir=model_dir,
            feature_columns=base_columns + crossed_columns)

    if model_type == "deep":
        return tf.estimator.DNNClassifier(
            model_dir=model_dir,
            feature_columns=deep_columns,
            hidden_units=[100, 50])

    # Every remaining value (including "wide_n_deep") falls through to the
    # combined model.
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=crossed_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=[100, 50])
def input_fn(data_file, num_epochs, shuffle):
    """Build a labelled Estimator input function from a census CSV file.

    Args:
        data_file: Path of the CSV file, opened through `tf.gfile.Open`.
        num_epochs: Forwarded to `pandas_input_fn` (None yields an endless
            stream of data).
        shuffle: Forwarded to `pandas_input_fn`.

    Returns:
        The input function produced by `tf.estimator.inputs.pandas_input_fn`,
        yielding batches of 100 rows read with 5 threads.
    """
    # pandas parses the whole file inside read_csv, so the handle can be
    # released as soon as the call returns.
    with tf.gfile.Open(data_file) as csv_file:
        df_data = pd.read_csv(
            csv_file,
            names=CSV_COLUMNS,
            skipinitialspace=True,
            engine="python",
            skiprows=1)
    # remove NaN elements
    df_data = df_data.dropna(how="any", axis=0)
    # Label is 1 when the income bracket text contains ">50K", else 0.
    labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
    return tf.estimator.inputs.pandas_input_fn(
        x=df_data,
        y=labels,
        batch_size=100,
        num_epochs=num_epochs,
        shuffle=shuffle,
        num_threads=5)
def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
    """Train the chosen model, evaluate it, and print the resulting metrics.

    Args:
        model_dir: Directory for model output; a fresh temporary directory is
            created when this is empty.
        model_type: Model selector forwarded to `build_estimator`.
        train_steps: Number of training steps.
        train_data: Path to the training data, or empty to download it.
        test_data: Path to the test data, or empty to download it.
    """
    train_file_name, test_file_name = maybe_download(train_data, test_data)
    if not model_dir:
        model_dir = tempfile.mkdtemp()
    classifier = build_estimator(model_dir, model_type)

    # num_epochs=None gives an endless stream; training stops after
    # `train_steps` steps.
    training_input = input_fn(train_file_name, num_epochs=None, shuffle=True)
    classifier.train(input_fn=training_input, steps=train_steps)

    # steps=None makes evaluation run until the single epoch is consumed.
    evaluation_input = input_fn(test_file_name, num_epochs=1, shuffle=False)
    metrics = classifier.evaluate(input_fn=evaluation_input, steps=None)

    print("model directory = %s" % model_dir)
    for metric_name in sorted(metrics):
        print("%s: %s" % (metric_name, metrics[metric_name]))
# Parsed command-line flags; filled in by the __main__ block.
FLAGS = None

# Combined wide-and-deep estimator created at import time with "./models" as
# its model directory; main() uses it for prediction.
estimator = build_estimator(model_dir="./models", model_type="wide_n_deep")
def predict_input_fn(data_file):
    """Build an unlabelled, unshuffled input function for prediction.

    Args:
        data_file: Path of the CSV file, opened through `tf.gfile.Open`.

    Returns:
        The input function produced by `tf.estimator.inputs.pandas_input_fn`
        with `y=None` and `shuffle=False`.
    """
    # pandas parses the whole file inside read_csv, so the handle can be
    # released as soon as the call returns.
    with tf.gfile.Open(data_file) as csv_file:
        df_data = pd.read_csv(
            csv_file,
            names=CSV_COLUMNS,
            skipinitialspace=True,
            engine="python",
            skiprows=1)
    # remove NaN elements
    df_data = df_data.dropna(how="any", axis=0)
    return tf.estimator.inputs.pandas_input_fn(x=df_data, y=None, shuffle=False)
def main(_):
    """Train and evaluate the model, then print predictions for "predict.csv".

    Args:
        _: Unused; receives the argument that `tf.app.run` passes to `main`.
    """
    train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps,
                   FLAGS.train_data, FLAGS.test_data)

    predict_file_name = "predict.csv"
    # An Estimator object is not callable; predictions are obtained through
    # its predict() method, which is iterated below.
    # NOTE: the module-level `estimator` is built with model_dir "./models",
    # so it only shares the model trained above when --model_dir points to
    # that same directory.
    predict_results = estimator.predict(
        input_fn=predict_input_fn(predict_file_name))
    for idx, prediction in enumerate(predict_results):
        print(idx)
        for key in prediction:
            print("...{}: {}".format(key, prediction[key]))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")

    # (flag, type, default, help text) for every command-line option.
    flag_specs = (
        ("--model_dir", str, "",
         "Base directory for output models."),
        ("--model_type", str, "wide_n_deep",
         "Valid model types: {'wide', 'deep', 'wide_n_deep'}."),
        ("--train_steps", int, 2000,
         "Number of training steps."),
        ("--train_data", str, "",
         "Path to the training data."),
        ("--test_data", str, "",
         "Path to the test data."),
    )
    for flag_name, flag_type, flag_default, flag_help in flag_specs:
        parser.add_argument(
            flag_name, type=flag_type, default=flag_default, help=flag_help)

    # Flags this script does not recognise are passed on to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment