Last active
October 17, 2017 18:50
-
-
Save delta2323/45e642dbbbc81baef0f686af5a056c5c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ConvolutionalAutoEncoder(chainer.Chain):
    """Convolutional autoencoder that computes its own reconstruction loss.

    The encoder is two stacked ``L.Convolution2D`` links (``c1``, ``c2``) and
    the decoder is two stacked ``L.Deconvolution2D`` links (``dc1``, ``dc2``).
    The ``...`` arguments are placeholders: choose channel counts, kernel
    sizes, strides and pads so that the decoder output has the same shape as
    the encoder input (required by ``F.mean_squared_error`` in ``__call__``).

    Calling the model returns the mean squared error between the input and
    its reconstruction, so the model can be given to an optimizer/updater
    directly instead of being wrapped in ``L.Classifier``.
    """

    def __init__(self):
        # ``self`` must be passed to super(): the one-argument form
        # ``super(ConvolutionalAutoEncoder)`` is an *unbound* super object,
        # so ``chainer.Chain.__init__`` would never run on this instance and
        # the links below would not be registered (the call fails).
        super(ConvolutionalAutoEncoder, self).__init__(
            c1=L.Convolution2D(...),
            c2=L.Convolution2D(...),
            dc1=L.Deconvolution2D(...),
            dc2=L.Deconvolution2D(...),
        )

    def convolve(self, x):
        """Encode ``x`` with ``c1`` followed by ``c2``.

        Optionally insert ``F.max_pooling_2d`` between/after the layers; if
        you do, the decoder must upsample correspondingly so that the
        reconstruction still matches the input shape.
        """
        return self.c2(self.c1(x))

    def deconvolve(self, h):
        """Decode an encoded batch ``h`` with ``dc1`` followed by ``dc2``."""
        return self.dc2(self.dc1(h))

    def reconstruct(self, x):
        """Return the reconstruction of ``x`` (encode, then decode)."""
        return self.deconvolve(self.convolve(x))

    def __call__(self, x):
        """Return the mean squared reconstruction error for the batch ``x``.

        Instead of using ``L.Classifier``, the loss is computed inside the
        autoencoder. Use ``reconstruct`` to get the reconstructed images.
        """
        x_hat = self.reconstruct(x)
        # x and x_hat must have the same shape for mean_squared_error.
        loss = F.mean_squared_error(x, x_hat)
        return loss
model = ConvolutionalAutoEncoder()

# Create optimizer, updater and trainer as usual (elided in this example);
# ``model`` can be used directly as the loss function because its
# ``__call__`` returns the reconstruction loss.
trainer.run()

# Convolve and deconvolve test images.
images = get_test_batch()
# Inference only: disable graph construction so no backprop graph is built
# for these forward passes. If dropout/batch-normalization layers are added
# later, also wrap this in ``chainer.using_config('train', False)``.
with chainer.no_backprop_mode():
    # ``convolved_images`` is a ``chainer.Variable``; the raw array is
    # available through its ``.data`` attribute.
    convolved_images = model.convolve(images)
    deconvolved_images = model.deconvolve(convolved_images)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment