Created
November 11, 2017 13:11
-
-
Save oiehot/40cd04305862b9dc346911f3bfede02f to your computer and use it in GitHub Desktop.
Tensorflow 심층신경망(DNN)을 이용하여 꽃 분류하기
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Tensorflow 심층신경망(DNN)을 이용하여 꽃 분류하기 (Classification) | |
순서: | |
1. CSV로 부터 Iris 훈련/시험 데이터를 읽는다. | |
2. 분류하는 신경망을 만든다. | |
3. 데이터를 통한 훈련. | |
4. 새로운 샘플을 통해 판별하기. | |
SL SW PL PW species | |
1 5.1 3.5 1.4 0.2 0(setosa) | |
2 4.9 3.0 1.4 0.2 0(setosa) | |
3 4.7 3.2 1.3 0.2 0(setosa) | |
4 4.6 3.1 1.5 0.2 0(setosa) | |
5 5.0 3.6 1.4 0.2 0(setosa) | |
6 5.4 3.9 1.7 0.4 0(setosa) | |
col[0]: Sepal Length (꽃받침의 길이) | |
col[1]: Sepal Width (꽃받침의 너비) | |
col[2]: Petal Length (꽃잎의 길이) | |
col[3]: Petal Width (꽃일의 너비) | |
col[4]: Species | |
0. Iris setosa (부채붓꽃) | |
1. Iris versicolor (북방푸른꽃창포) | |
2. Iris virginica | |
* Iris data set: 통계학자인 피셔Fisher가 소개한 데이터. 피셔는 통계학자, 유전학자, 진화생물학자로 현대 통계학에서 지대한 공을 세운 학자다. | |
* Iris: 붓꽃과의 한 속. 200~300 종. 영어로 아이리스라 부르며 무지개를 뜻하는 '이리스'에서 왔다. | |
''' | |
import os | |
import tensorflow as tf | |
import numpy as np | |
import pandas as pd | |
from urllib.request import urlopen | |
IRIS_TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv" | |
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv" | |
IRIS_TRAIN_FILE = 'iris_training.csv' | |
IRIS_TEST_FILE = 'iris_test.csv' | |
# 데이터를 다운로드 받는다. | |
def download(url, filename): | |
if not os.path.exists(filename): | |
raw = urlopen(url).read() | |
with open(filename, 'wb') as f: | |
f.write(raw) | |
print('다운로드 %s => %s' % (url, filename)) | |
download(IRIS_TRAIN_URL, IRIS_TRAIN_FILE) | |
download(IRIS_TEST_URL, IRIS_TEST_FILE) | |
# 다운로드된 CSV를 통해 Dataset 인스턴스를 얻는다. | |
# Dataset은 네임드 튜플이다. (data, target) | |
train_set = tf.contrib.learn.datasets.base.load_csv_with_header( | |
filename=IRIS_TRAIN_FILE, | |
target_dtype=np.int, | |
features_dtype=np.float32) | |
test_set = tf.contrib.learn.datasets.base.load_csv_with_header( | |
filename=IRIS_TEST_FILE, | |
target_dtype=np.int, | |
features_dtype=np.float32) | |
# 모든 특성(Features, SL SW PL PW)이 실제 값을 가진 데이터라고 지정한다. | |
feature_columns = [tf.feature_column.numeric_column('x', shape=[4])] | |
# 10, 20, 10 단위의 3층 DNN을 만든다. (DNNClassifier) | |
classifier = tf.estimator.DNNClassifier( | |
feature_columns=feature_columns, | |
hidden_units=[10, 20, 10], | |
n_classes=3, | |
model_dir='d:/tmp' | |
) | |
# * hidden_units: 은닉층 | |
# * n_classes: 타겟 클래스, 여기서는 3개의 품종 중 하나를 나타낸다. | |
# * model_dir: 텐서플로우가 훈련중에 데이터를 저장하는 사용하는 디렉토리 경로. | |
# 훈련용 데이터 입력 함수 | |
train_input_fn = tf.estimator.inputs.numpy_input_fn( | |
x={'x': np.array(train_set.data)}, | |
y=np.array(train_set.target), | |
num_epochs=None, | |
shuffle=True | |
) | |
# 훈련용 데이터를 이용해서 DNN신경망을 훈련시킨다. (train) | |
print('훈련중...') | |
classifier.train(input_fn=train_input_fn, steps=2000) | |
print('훈련완료.') | |
# 시험용 데이터를 이용한 입력 함수를 준비한다. | |
test_input_fn = tf.estimator.inputs.numpy_input_fn( | |
x={'x': np.array(test_set.data)}, | |
y=np.array(test_set.target), | |
num_epochs=1, | |
shuffle=False | |
) | |
# 훈련된 신경망을 사용하여 테스트 데이터의 정확도를 계산한다. (evaluate) | |
score = classifier.evaluate(input_fn=test_input_fn)['accuracy'] | |
print('현재 정확도: %f' % score) | |
## 추론해보기 | |
# 추론용 데이터 | |
# samples = np.array( | |
# [[6.4, 3.2, 4.5, 1.5], | |
# [5.8, 3.1, 5.0, 1.7]], | |
# dtype=np.float32) | |
samples = np.random.random(size=(100,4)) * 5 | |
# 추론용 데이터 입력함수 | |
predict_input_fn = tf.estimator.inputs.numpy_input_fn( | |
x={'x': samples}, | |
num_epochs=1, | |
shuffle=False) | |
# 추론하기 | |
predictions = list(classifier.predict(input_fn=predict_input_fn)) | |
# 추론결과 표시 | |
predicted_classes = [p['classes'] for p in predictions] | |
result = np.hstack((samples, predicted_classes)) | |
df = pd.DataFrame(result, columns=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']) | |
mapping = {b'0':'setosa', b'1':'versicolor', b'2':'virginica'} | |
final = df.replace({'species':mapping}) | |
print(final) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment