Last active: August 29, 2015 14:20
-
-
Save piyo7/b4f49821248b80f680a0 to your computer and use it in GitHub Desktop.
ディープラーニング勉強会 (Deep Learning study group): AutoEncoder — ref: http://qiita.com/piyo7/items/60576759430910ffe5be
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.math | |
import scala.util.Random | |
import MatrixImplicits._ // 自作のimplicit classによるSeqラッパ | |
/**
 * Denoising autoencoder (dA) with one hidden layer.
 *
 * Encoding multiplies by `W` and decoding by `W.T` (presumably the transpose), so a single
 * weight matrix is shared by both directions. Both layers use the logistic sigmoid.
 *
 * Vectors and matrices are plain `Seq[Double]` / `Seq[Seq[Double]]`. The operators used on them
 * (`mXc`, `cXr`, `.T`, element-wise `+`/`-`/`*`, `map2`) come from the project's
 * `MatrixImplicits` wrapper imported at the top of the file.
 * NOTE(review): `mXc` is read here as matrix-times-vector and `cXr` as an outer product,
 * based on their names and how the dimensions line up. Confirm against MatrixImplicits.
 *
 * @param train_N   divisor applied to every update step (presumably the training-set size);
 *                  must be positive
 * @param n_visible size of the visible (input) layer
 * @param n_hidden  size of the hidden layer
 * @param seed      seed for the generator used for weight initialization and input corruption
 */
class dA(train_N: Int, n_visible: Int, n_hidden: Int, seed: Int) {
  // train_N divides every parameter update. Zero would yield Infinity/NaN weights,
  // and a negative value would silently flip the update direction.
  require(train_N > 0, s"train_N must be positive, got $train_N")

  /** Random source shared by weight initialization and input corruption. */
  val rng: Random = new Random(seed)

  /** Weight matrix: `n_hidden` rows of `n_visible` entries, each uniform in [-1/n_visible, 1/n_visible). */
  var W: Seq[Seq[Double]] = Seq.fill(n_hidden, n_visible)(uniform(-1.0 / n_visible, 1.0 / n_visible))

  /** Hidden-layer bias, initialized to zeros. */
  var hbias: Seq[Double] = Seq.fill(n_hidden)(0.0)

  /** Visible-layer bias, initialized to zeros. */
  var vbias: Seq[Double] = Seq.fill(n_visible)(0.0)

  /** Draws a value uniformly from [min, max) using `rng`. */
  def uniform(min: Double, max: Double): Double = rng.nextDouble() * (max - min) + min

  /** Logistic sigmoid 1 / (1 + e^-x). */
  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  /**
   * Masking noise: each element of `x` is zeroed independently with probability `p`
   * and otherwise kept unchanged.
   */
  def corrupted(x: Seq[Double], p: Double): Seq[Double] =
    x.map(_ * (if (rng.nextDouble() < p) 0.0 else 1.0))

  /** Hidden activations: sigmoid(W x + hbias), element-wise. */
  def encode(x: Seq[Double]): Seq[Double] = ((W mXc x) + hbias).map(sigmoid)

  /** Reconstruction of the visible layer: sigmoid(W.T y + vbias), element-wise. */
  def decode(y: Seq[Double]): Seq[Double] = ((W.T mXc y) + vbias).map(sigmoid)

  /**
   * Performs one update step on the single example `x`.
   *
   * The input is corrupted, encoded, and decoded. Then `vbias`, `hbias` and `W` are each
   * moved by their error term scaled by `learning_rate / train_N`.
   *
   * @param x                clean input example (length `n_visible` expected)
   * @param learning_rate    step size, further divided by `train_N`
   * @param corruption_level probability of zeroing each input element before encoding
   */
  def train(x: Seq[Double], learning_rate: Double, corruption_level: Double): Unit = {
    val tilde_x = corrupted(x, corruption_level)
    val y = encode(tilde_x)
    val z = decode(y)

    // Reconstruction error against the *clean* input; used directly as the visible-bias update.
    val L_vbias = x - z
    // Propagate the error back through W and the hidden sigmoid (derivative y * (1 - y)).
    // This must read W before W is reassigned below.
    val L_hbias = (W mXc L_vbias) * y * y.map(1.0 - _)

    vbias = vbias + L_vbias.map(_ * learning_rate / train_N)
    hbias = hbias + L_hbias.map(_ * learning_rate / train_N)
    // W appears in both the encoder (W) and the decoder (W.T),
    // so its update sums one term from each path.
    W = W + ((L_hbias cXr tilde_x) + (y cXr L_vbias)).map2(_ * learning_rate / train_N)
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.