# -*- coding: utf-8 -*-
"""TensorFlow Wide & Deep tutorial example using the tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers
# Columns used for training.
_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]
# Default values for each column.
_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                        [0], [0], [0], [''], ['']]
# Number of examples in the training and validation sets.
_NUM_EXAMPLES = {
    'train': 32561,
    'validation': 16281,
}
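
# The estimator heads namespace their loss tensors ('linear/...' for the wide
# model, 'dnn/...' for the deep model); these prefixes let the logging hooks
# configured below locate them.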
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}


def define_wide_deep_flags():
  """Defines flags for the model type and for training."""
  flags_core.define_base()
  flags_core.define_benchmark()

  flags.adopt_module_key_flags(flags_core)

  flags.DEFINE_enum(
      name="model_type", short_name="mt", default="wide_deep",
      enum_values=['wide', 'deep', 'wide_deep'],
      help="Select model topology.")

  flags_core.set_defaults(data_dir='/tmp/census_data',
                          model_dir='/tmp/census_model',
                          train_epochs=40,
                          epochs_between_evals=2,
                          batch_size=40)
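
# Example invocation (a sketch; the script filename is an assumption, and the
# census data must already be in --data_dir, e.g. downloaded with the official
# data_download.py):
#   python wide_deep.py --model_type=deep --train_epochs=10 --batch_size=64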


def build_model_columns():
  """Builds the wide and deep feature columns."""
  # Continuous columns.
  age = tf.feature_column.numeric_column('age')
  education_num = tf.feature_column.numeric_column('education_num')
  capital_gain = tf.feature_column.numeric_column('capital_gain')
  capital_loss = tf.feature_column.numeric_column('capital_loss')
  hours_per_week = tf.feature_column.numeric_column('hours_per_week')
  education = tf.feature_column.categorical_column_with_vocabulary_list(
      'education', [
          'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
          'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
          '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

  marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
      'marital_status', [
          'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
          'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])

  relationship = tf.feature_column.categorical_column_with_vocabulary_list(
      'relationship', [
          'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
          'Other-relative'])

  workclass = tf.feature_column.categorical_column_with_vocabulary_list(
      'workclass', [
          'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
          'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

  # To show an example of hashing:
  occupation = tf.feature_column.categorical_column_with_hash_bucket(
      'occupation', hash_bucket_size=1000)
  # Transformations.
  age_buckets = tf.feature_column.bucketized_column(
      age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
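  # For example, an age of 34 falls into the [30, 35) bucket; with 10
  # boundaries the column is one-hot encoded over 11 buckets.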
  # Wide columns and deep columns.
  base_columns = [
      education, marital_status, relationship, workclass, occupation,
      age_buckets,
  ]

  crossed_columns = [
      tf.feature_column.crossed_column(
          ['education', 'occupation'], hash_bucket_size=1000),
      tf.feature_column.crossed_column(
          [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
  ]
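  # A crossed column hashes the combination of its inputs (e.g. 'Bachelors' x
  # 'Exec-managerial') into one of hash_bucket_size buckets, so the linear
  # model can learn a weight for each specific feature combination.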
  wide_columns = base_columns + crossed_columns

  deep_columns = [
      age,
      education_num,
      capital_gain,
      capital_loss,
      hours_per_week,
      tf.feature_column.indicator_column(workclass),
      tf.feature_column.indicator_column(education),
      tf.feature_column.indicator_column(marital_status),
      tf.feature_column.indicator_column(relationship),
      # To show an example of embedding
      tf.feature_column.embedding_column(occupation, dimension=8),
  ]

  return wide_columns, deep_columns


def build_estimator(model_dir, model_type):
  """Builds an estimator appropriate for the given model type."""
  wide_columns, deep_columns = build_model_columns()
  hidden_units = [100, 75, 50, 25]

  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
  # trains faster than GPU for this model.
  run_config = tf.estimator.RunConfig().replace(
      session_config=tf.ConfigProto(device_count={'GPU': 0}))
  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=wide_columns,
        config=run_config)
  elif model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        config=run_config)
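  # The 'wide_deep' case trains both parts jointly on a single loss; by
  # default the linear part uses the FTRL optimizer and the DNN part Adagrad.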
  else:
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=hidden_units,
        config=run_config)


def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Defines the input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run data_download.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  # Parse one line of the CSV file.
  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')
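
  # Note: rows whose income_bracket is '>50K' become the positive class;
  # tf.equal produces a boolean label tensor.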
  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
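  # Each element of the resulting dataset is a (features, labels) pair: a dict
  # mapping column names to batched tensors, and a boolean label batch of
  # shape [batch_size].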
  return dataset


def export_model(model, model_type, export_dir):
  """Exports the model in SavedModel format.

  Args:
    model: Estimator object
    model_type: string indicating the model type, e.g. "wide", "deep" or
      "wide_deep"
    export_dir: directory to export the model to
  """
  wide_columns, deep_columns = build_model_columns()
  if model_type == 'wide':
    columns = wide_columns
  elif model_type == 'deep':
    columns = deep_columns
  else:
    columns = wide_columns + deep_columns
  feature_spec = tf.feature_column.make_parse_example_spec(columns)
  example_input_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
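  # The resulting SavedModel's serving signature accepts serialized tf.Example
  # protos whose features match feature_spec.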
  model.export_savedmodel(export_dir, example_input_fn)


def run_wide_deep(flags_obj):
  """Runs Wide-Deep training and evaluation.

  Args:
    flags_obj: parsed flags
  """
  # Clean up the model directory if it already exists.
  shutil.rmtree(flags_obj.model_dir, ignore_errors=True)

  # Create the estimator model using the tf.estimator API.
  model = build_estimator(flags_obj.model_dir, flags_obj.model_type)

  # Set up paths to the training and test data files.
  train_file = os.path.join(flags_obj.data_dir, 'adult.data')
  test_file = os.path.join(flags_obj.data_dir, 'adult.test')
  # Input functions for training and evaluation.
  def train_input_fn():
    return input_fn(
        train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)

  def eval_input_fn():
    return input_fn(test_file, 1, False, flags_obj.batch_size)
  run_params = {
      'batch_size': flags_obj.batch_size,
      'train_epochs': flags_obj.train_epochs,
      'model_type': flags_obj.model_type,
  }

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params,
                                test_id=flags_obj.benchmark_test_id)

  loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, batch_size=flags_obj.batch_size,
      tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
                      'loss': loss_prefix + 'head/weighted_loss/Sum'})
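
  # Note: 'head/truediv' and 'head/weighted_loss/Sum' are internal graph
  # tensor names created by the TF 1.x estimator heads; they may change
  # between TensorFlow versions.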
  # Train, and run evaluation every flags.epochs_between_evals epochs.
  for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    model.train(input_fn=train_input_fn, hooks=train_hooks)
    results = model.evaluate(input_fn=eval_input_fn)

    # Display the evaluation results.
    tf.logging.info('Results at epoch %d / %d',
                    (n + 1) * flags_obj.epochs_between_evals,
                    flags_obj.train_epochs)
    tf.logging.info('-' * 60)

    for key in sorted(results):
      tf.logging.info('%s: %s' % (key, results[key]))

    benchmark_logger.log_evaluation_result(results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, results['accuracy']):
      break
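
  # past_stop_threshold returns True once accuracy reaches --stop_threshold
  # (when that flag is set), ending training early.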
  # Export the model.
  if flags_obj.export_dir is not None:
    export_model(model, flags_obj.model_type, flags_obj.export_dir)


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    run_wide_deep(flags.FLAGS)


if __name__ == '__main__':
  # Set the logging verbosity.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Define the flags.
  define_wide_deep_flags()
  # Run the main function.
  absl_app.run(main)