Created
April 8, 2017 01:49
-
-
Save oiehot/e1806306ed8b62551de63bdc028da652 to your computer and use it in GitHub Desktop.
오차역전파법을 위한 덧셈층, 곱셈층
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AddLayer:
    """Addition node of a computational graph (forward and backward pass)."""

    def __init__(self):
        # Addition needs no cached inputs: its local gradient is always 1.
        pass

    def forward(self, x, y):
        """Return the sum ``x + y``."""
        return x + y

    def backward(self, dout):
        """Return the upstream gradient ``dout`` for both inputs.

        d(x + y)/dx = d(x + y)/dy = 1, so each input receives ``dout * 1``.

        Returns:
            A ``(dx, dy)`` tuple.
        """
        dx = dout * 1
        dy = dout * 1
        return dx, dy
class MulLayer:
    """Multiplication node of a computational graph (forward and backward pass)."""

    def __init__(self):
        # Operands of the most recent forward() call; backward() needs them.
        self.x = None
        self.y = None

    def forward(self, x, y):
        """Cache both operands and return the product ``x * y``."""
        self.x = x
        self.y = y
        return x * y

    def backward(self, dout):
        """Return the gradients of the product with respect to each input.

        d(x * y)/dx = y and d(x * y)/dy = x, so the upstream gradient is
        multiplied by the *other* operand cached during forward().

        Returns:
            A ``(dx, dy)`` tuple.
        """
        dx = dout * self.y
        dy = dout * self.x
        return dx, dy
# Demo: price of 2 apples and 3 oranges with tax, computed as a small
# computational graph, followed by backpropagation through the same graph.

# Inputs.
apple_price = 100
apple_num = 2
orange_price = 150
orange_num = 3
tax = 1.1

print("사과값:", apple_price)  # 100
print("사과개수:", apple_num)  # 2
print("오렌지값:", orange_price)  # 150
print("오렌지개수:", orange_num)  # 3
print("세금:", tax)  # 1.1

# One layer object per graph node; each MulLayer caches its own operands.
apple_mul_layer = MulLayer()
orange_mul_layer = MulLayer()
add_apple_orange_layer = AddLayer()
tax_mul_layer = MulLayer()

# Forward pass: (apple_price * apple_num + orange_price * orange_num) * tax.
apple_total_price = apple_mul_layer.forward(apple_price, apple_num)
orange_total_price = orange_mul_layer.forward(orange_price, orange_num)
total_price = add_apple_orange_layer.forward(apple_total_price, orange_total_price)
tax_price = tax_mul_layer.forward(total_price, tax)

print("세전가:", total_price)  # 650
print("세후가:", tax_price)  # ~715 (float rounding may show extra digits)

# Backward pass: visit the layers in the reverse order of the forward pass,
# starting from d(tax_price)/d(tax_price) = 1.
d_price = 1
d_all_price, d_tax = tax_mul_layer.backward(d_price)
d_apple_price, d_orange_price = add_apple_orange_layer.backward(d_all_price)
d_apple, d_apple_num = apple_mul_layer.backward(d_apple_price)
d_orange, d_orange_num = orange_mul_layer.backward(d_orange_price)

print("세후가/세전가 미분:", d_all_price)  # 1.1
print("세후가/세금 미분:", d_tax)  # 650
print("세후가/전체사과값 미분:", d_apple_price)  # 1.1
print("세후가/전체오렌지값 미분:", d_orange_price)  # 1.1
print("세후가/사과값 미분:", d_apple)  # 2.2
print("세후가/사과개수 미분:", d_apple_num)  # ~110
print("세후가/오렌지값 미분:", d_orange)  # ~3.3
print("세후가/오렌지개수 미분:", d_orange_num)  # 165
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment