Created
May 17, 2017 07:27
-
-
Save ytbilly3636/ed75da0029e10b00cac0ccfaeb40f5ec to your computer and use it in GitHub Desktop.
ChainerでSOM?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import cv2 | |
import numpy as np | |
import chainer | |
import chainer.links as L | |
from chainer import Chain, Variable | |
from chainer import datasets | |
class SOM(Chain): | |
# SOMをchainer.links.Linearで表現 | |
# 競合層は2次元マップなので2次元→1次元変換を行いながら計算する | |
# widthは2次元マップの幅 全部で競合層のニューロンはwidth x width個 | |
def __init__(self, width): | |
self.width = width | |
super(SOM, self).__init__( | |
competitive = L.Linear(in_size=None, out_size=self.width * self.width, nobias=True) | |
) | |
return | |
# 勝者決定アルゴリズム | |
# 入力ベクトル(x)に最も類似する重みベクトルを持つニューロンを探す | |
# 類似度は内積によって計算する | |
# 返り値はニューロンのマップ上の座標 | |
def predict(self, x): | |
ip = self.competitive(x) | |
pos = np.argmax(ip.data) | |
return pos/self.width, pos%self.width | |
# 近傍関数算出アルゴリズム | |
# ガウス関数によって定義 | |
# centerはガウス関数の中心座標、varはガウス関数の分散 | |
def __neighbor(self, center, var): | |
y = np.abs(np.arange(-center[0], self.width - center[0])) | |
x = np.abs(np.arange(-center[1], self.width - center[1])) | |
xx, yy = np.meshgrid(x, y) | |
d2 = xx**2 + yy**2 | |
return np.exp(- d2 / (2 * (var**2))) | |
# インクリメンタル方式の学習(←→バッチ方式) | |
# xは入力ベクトル、lrは学習係数、varはガウス関数の分散 | |
# reinforceは強化・減衰の指定 今回は実装していないがここを活用すればLVQ1が実装可能 | |
def incr_learn(self, x, lr=0.1, var=4.0, reinforce=True): | |
pos = self.predict(x) | |
delta_x = (lr * self.__neighbor(pos, var).reshape(1, -1)).T.dot(x.data) | |
delta_w = (lr * self.__neighbor(pos, var).reshape(1, -1)).T * self.competitive.W.data | |
self.competitive.W.data += delta_x if reinforce else - delta_x | |
self.competitive.W.data -= delta_w if reinforce else - delta_w | |
return | |
# マップの可視化 | |
# 入力ベクトルが正方形の画像であるときのみ利用可能(汎用性なし) | |
# in_widthは入力画像の幅、chは入力画像のチャンネル | |
def weight_show(self, in_width, ch): | |
show_array = np.zeros((in_width*self.width, in_width*self.width, ch), dtype=np.float32) | |
for i, c in enumerate(self.competitive.W.data): | |
y = i / self.width | |
x = i % self.width | |
if ch == 3: | |
show_array[y*in_width:(y+1)*in_width, x*in_width:(x+1)*in_width] = cv2.cvtColor(np.rollaxis(c.reshape(ch, in_width, in_width), 0, 3), cv2.COLOR_RGB2BGR) | |
else: | |
show_array[y*in_width:(y+1)*in_width, x*in_width:(x+1)*in_width] = c.reshape(in_width, in_width) | |
cv2.imshow('win', show_array) | |
cv2.waitKey(1) | |
return | |
# 10x10のマップを用意 | |
som = SOM(width=10) | |
# 入力としてMNISTデータセットを使用 | |
# CIFAR10でも可 | |
train, test = datasets.get_mnist() | |
# 1バッチを学習 | |
for it, tr in enumerate(train): | |
if it % 5000 == 0: | |
print 'iter:', it | |
# 入力ベクトルを用意 | |
x = Variable(np.array([tr[0]], dtype=np.float32)) | |
# 学習率とガウス関数の分散は徐々に小さくしていく | |
lr = 0.05 * (1.0 - float(it) / len(train)) | |
var = 2.0 * (1.0 - float(it) / len(train)) | |
# 学習 | |
som.incr_learn(x, lr=lr, var=var) | |
# 可視化 | |
som.weight_show(in_width=28, ch=1) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment