Created
October 12, 2018 12:34
-
-
Save solaris33/3c9d2f59b6a0ee87699e26adc1a1b511 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
import time

import numpy as np
import tensorflow as tf
def get_input_fn(dataset_split, batch_size, capacity=10000,
                 min_after_dequeue=3000, num_threads=4):
    """Build an input function that produces shuffled mini-batches.

    Args:
        dataset_split: Object exposing ``images`` and ``labels`` attributes.
            ``labels`` must support ``astype`` (presumably a NumPy array --
            confirm against the caller).
        batch_size: Forwarded to ``tf.train.shuffle_batch`` as ``batch_size``.
        capacity: Forwarded to ``tf.train.shuffle_batch`` as ``capacity``.
        min_after_dequeue: Forwarded to ``tf.train.shuffle_batch`` as
            ``min_after_dequeue``.
        num_threads: Forwarded to ``tf.train.shuffle_batch`` as
            ``num_threads``. Defaults to 4, the value that used to be
            hard-coded here.

    Returns:
        A zero-argument callable. When invoked it returns a tuple
        ``(features_map, labels_batch)`` where ``features_map`` is
        ``{'images': images_batch}``.
    """

    def _input_fn():
        # Labels are cast to int32 before being handed to the queue.
        # enqueue_many=True: each input tensor holds many examples, not one.
        images_batch, labels_batch = tf.train.shuffle_batch(
            tensors=[dataset_split.images,
                     dataset_split.labels.astype(np.int32)],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
            enqueue_many=True,
            num_threads=num_threads)
        # The 'images' key must match the feature column name used by the
        # model that consumes this input function.
        features_map = {'images': images_batch}
        return features_map, labels_batch

    return _input_fn
# Load the MNIST data and build the input functions: training batches of 256
# come from the train split, and one 5000-example batch comes from the
# validation split.
data = tf.contrib.learn.datasets.mnist.load_mnist()
train_input_fn = get_input_fn(data.train, batch_size=256)
eval_input_fn = get_input_fn(data.validation, batch_size=5000)
# Declare the input feature column and the KernelLinearClassifier, which
# combines a kernel mapping with a linear classifier.
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
optimizer = tf.train.FtrlOptimizer(
    learning_rate=50.0, l2_regularization_strength=0.001)
# Random Fourier feature mapper: 784-dimensional input -> 2000-dimensional
# output.
kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
    input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
kernel_mappers = {image_column: [kernel_mapper]}
estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
    n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
# Train for 2000 steps and print how long training took.
start = time.time()
estimator.fit(input_fn=train_input_fn, steps=2000)
end = time.time()
print('Elapsed time: {} seconds'.format(end - start))
# Measure the trained KernelLinearClassifier on a single batch of 5000
# examples. Note that eval_input_fn draws from the validation split, not the
# test split.
eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
print(eval_metrics)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment