Created
January 13, 2017 06:40
-
-
Save ronekko/9783e1718db0aea6f1e369b937e4dcd5 to your computer and use it in GitHub Desktop.
全結合Linkを1×1のConvolution2Dに置き換える例
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Fri Jan 13 14:52:21 2017 | |
| @author: sakurai | |
| 平均値プーリングと全結合変換は、適用順序を入れ替えても同じ結果になる。 | |
| 全結合Linkを1×1のConvolution2Dに置き換える例。 | |
| """ | |
| import numpy as np | |
| import chainer.functions as F | |
| import chainer.links as L | |
def pool_affine(fc, x):
    """Average-pool ``x`` over its full spatial extent, then apply ``fc``.

    Args:
        fc: The affine link applied after pooling (``L.Linear`` in the demo
            below).
        x: Input array for ``F.average_pooling_2d``. The pooling window is
            ``x.shape[2:]``, so for an NCHW input the whole H x W map is
            reduced to 1 x 1.

    Returns:
        Tuple ``(pooled, output)``: the pooled variable and ``fc(pooled)``.
    """
    pooled = F.average_pooling_2d(x, x.shape[2:])
    return pooled, fc(pooled)
def affine_pool(conv, x):
    """Apply ``conv`` to ``x``, then average-pool the result.

    Args:
        conv: The link applied before pooling (a 1x1 ``L.Convolution2D`` in
            the demo below).
        x: Input array. The pooling window is taken from ``x.shape[2:]``
            (the input's spatial size, not the conv output's), which matches
            the conv output only when ``conv`` preserves spatial size.

    Returns:
        Tuple ``(mapped, output)``: the output of ``conv`` and its pooled
        version.
    """
    window = x.shape[2:]
    mapped = conv(x)
    return mapped, F.average_pooling_2d(mapped, window)
def main():
    """Demonstrate that global average pooling and an affine map commute.

    Builds a small deterministic input and applies the same affine map in
    two ways:

    1. Global average pooling, then ``L.Linear``.
    2. ``L.Convolution2D`` (1x1, with the Linear weights copied in), then
       global average pooling.

    It prints the intermediate and final values of both paths and then
    asserts that the two final outputs agree.

    Raises:
        AssertionError: if the two paths produce different values.
    """
    shape = 1, 3, 2, 2
    channels = shape[1]
    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    print('x:')
    print(x)

    # Global average pooling, then Linear.
    fc = L.Linear(channels, 1)
    # Presumably distinct powers of ten so each channel's contribution is
    # distinguishable in the printed result.
    fc.W.data[:] = np.array([[1, 100, 10000]], dtype=np.float32)
    fc.b.data[:] = np.array([0.5], dtype=np.float32)
    h_pa, y_pa = pool_affine(fc, x)
    print('# pool -> affine ##########################')
    print('h_pa:')
    print(h_pa.data)
    print('y_pa:')
    print(y_pa.data)

    # Linear (expressed as a 1x1 convolution), then global average pooling.
    conv = L.Convolution2D(in_channels=channels, out_channels=1, ksize=1)
    # Copy fc's parameters so both paths apply exactly the same affine map.
    conv.W.data[:] = fc.W.data.reshape(1, channels, 1, 1)
    conv.b.data[:] = fc.b.data
    h_ap, y_ap = affine_pool(conv, x)
    print('# affine -> pool ##########################')
    print('h_ap:')
    print(h_ap.data)
    print('y_ap:')
    print(y_ap.data)

    # y_pa is (N, out) while y_ap is (N, out, 1, 1): compare values, not
    # shapes. A float32-sized tolerance absorbs summation-order differences.
    np.testing.assert_allclose(y_pa.data.ravel(), y_ap.data.ravel(),
                               rtol=1e-5)
    print('OK: pool -> affine == affine -> pool')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment