Last active
February 19, 2017 04:07
-
-
Save oiehot/3da4da64caa07142eb41301ea197eed3 to your computer and use it in GitHub Desktop.
평가함수, 평균제곱오차 & 교차엔트로피오차
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys

import numpy as np

# Answer table (one-hot labels): row i is the one-hot encoding of the correct
# digit for sample i -- digits 2, 3, 5, 0, 1 respectively.
t = np.eye(10, dtype=int)[[2, 3, 5, 0, 1]]

# Predicted probabilities: each sample assigns a nonzero probability only to
# its correct class (0.8, 0.5, 0.3, 0.7, 0.2 respectively); all others are 0.0.
y = t * np.array([[.8], [.5], [.3], [.7], [.2]])
# Mean squared error (MSE).
def meanSquaredError(y, t):
    """Return half the sum of squared differences between ``y`` and ``t``.

    Despite the name, the value is ``0.5 * sum((y - t) ** 2)`` taken over
    every element; it is NOT divided by the number of elements or samples.

    Args:
        y: predicted values (NumPy array).
        t: target values, broadcast-compatible with ``y``.

    Returns:
        The scalar ``0.5 * sum((y - t) ** 2)``.
    """
    residual = y - t
    squared = np.square(residual)
    return 0.5 * np.sum(squared)
# Cross-entropy error (CEE).
def crossEntropyError(y, t):
    """Return the cross-entropy error ``-sum(t * log(y + 1e-7))``.

    The sum runs over every element, so for a 2-D batch the result is the
    total over all samples rather than a per-sample average.

    Args:
        y: predicted probabilities (NumPy array).
        t: one-hot targets, broadcast-compatible with ``y``.

    Returns:
        The scalar cross-entropy error.
    """
    # Small offset so log() stays finite when a predicted probability is 0.
    eps = 1e-7
    log_prob = np.log(y + eps)
    return -np.sum(t * log_prob)
def crossEntropyError_total(y, t):
    """Return the cross-entropy error averaged over the samples in a batch.

    Args:
        y: predicted probabilities, shape ``(N, C)`` for a batch of N samples,
            or shape ``(C,)`` for a single sample. Array-likes are accepted.
        t: one-hot targets with the same shape as ``y``.

    Returns:
        ``-sum(t * log(y + 1e-7)) / N``; a 1-D input counts as one sample.
        The inputs are not modified.
    """
    y = np.asarray(y)
    t = np.asarray(t)
    if y.ndim == 1:
        # A single sample: treat it as a batch of one row. reshape() returns a
        # new array rather than working in place, so the result must be kept.
        y = y.reshape(1, y.size)
        t = t.reshape(1, t.size)
    # Count samples only after reshaping, so a 1-D input divides by 1 rather
    # than by the number of classes.
    total_size = y.shape[0]
    # The 1e-7 offset keeps log() finite when a predicted probability is 0.
    return -np.sum(t * np.log(y + 1e-7)) / total_size
def crossEntropyError_batch(y, t, batch_size):
    """Return the mean cross-entropy error over the first ``batch_size`` samples.

    Samples are taken in order from the start of the batch (there is no random
    sampling). If ``batch_size`` exceeds the number of samples, every sample
    is used.

    Args:
        y: predicted probabilities, shape ``(N, C)``, or ``(C,)`` for a single
            sample. Array-likes are accepted.
        t: one-hot targets with the same shape as ``y``.
        batch_size: number of leading samples to evaluate; must be >= 1.

    Returns:
        ``-sum(t[:b] * log(y[:b] + 1e-7)) / b`` where ``b = min(batch_size, N)``.
        The inputs are not modified.

    Raises:
        ValueError: if ``batch_size`` is less than 1 (the average would be
            undefined).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got {!r}".format(batch_size))
    y = np.asarray(y)
    t = np.asarray(t)
    if y.ndim == 1:
        # A single sample: treat it as a batch of one row. reshape() returns a
        # new array rather than working in place, so the result must be kept;
        # otherwise the slicing below would pick individual class scores.
        y = y.reshape(1, y.size)
        t = t.reshape(1, t.size)
    # Clip to the number of rows, counted after the reshape above.
    batch_size = min(batch_size, y.shape[0])
    batch_y = y[:batch_size]
    batch_t = t[:batch_size]
    # The 1e-7 offset keeps log() finite when a predicted probability is 0.
    return -np.sum(batch_t * np.log(batch_y + 1e-7)) / batch_size
# Demo: evaluate and print each cross-entropy variant on the sample batch,
# one at a time and in this order.
for loss_fn, extra_args in (
    (crossEntropyError, ()),
    (crossEntropyError_total, ()),
    (crossEntropyError_batch, (3,)),
):
    print(loss_fn(y, t, *extra_args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment