Skip to content

Instantly share code, notes, and snippets.

@dlwh
Last active August 29, 2015 13:57
Show Gist options
  • Save dlwh/9734216 to your computer and use it in GitHub Desktop.
Client code:
/**
 * Non-negative matrix factorization (NMF) on the GPU, using `CuMatrix[Float]`.
 */
object NMF {

  /**
   * Solves for a non-negative coefficient matrix `H` such that `X ≈ W * H`, while the
   * basis `W` is held fixed ("supervised" NMF).
   *
   * Starting from `H = 1`, each iteration applies the element-wise multiplicative update
   * {{{
   *   H ← H ⊙ (Wᵀ (X ⊘ (W H))) ⊘ (Wᵀ 1)
   * }}}
   * (⊙ and ⊘ are element-wise multiply and divide; `Wᵀ 1` holds the column sums of `W`).
   * After the update, every entry of `H` is clamped to be at least `eps`.
   *
   * A progress line containing the iteration index is printed to stdout on every iteration.
   *
   * @param W     fixed basis, shape (n × r)
   * @param X     matrix to approximate, shape (n × m); must have the same number of rows as `W`
   * @param iters number of update iterations to run (no convergence check is performed)
   * @param eps   lower bound applied to every entry of `H` after each iteration
   * @return      the estimated coefficient matrix `H`, shape (r × m)
   * @throws IllegalArgumentException if `W.rows != X.rows`
   */
  def supervised(W: CuMatrix[Float], X: CuMatrix[Float], iters: Int = 200, eps: Float = 1E-6f): CuMatrix[Float] = {
    require(W.rows == X.rows, s"W has ${W.rows} rows but X has ${X.rows}; they must match")
    // Brings W's implicit cuBLAS handle into scope for the matrix products below.
    import W.blas
    val n = X.rows
    val m = X.cols
    val r = W.cols
    // W never changes, so take its transpose once instead of on every iteration.
    val Wt = W.t
    // Denominator of the update: Wᵀ·1 (r × 1) replicated across all m columns, giving r × m.
    val Wones = Wt * CuMatrix.ones[Float](n, 1) * CuMatrix.ones[Float](1, m)
    // H is updated strictly in place, so a single buffer is reused for all iterations.
    val H = CuMatrix.ones[Float](r, m)
    for (i <- 0 until iters) {
      println(i)
      // Multiplicative step: numerator first, then the (precomputed) denominator.
      H :*= (Wt * (X :/ (W * H)))
      H :/= Wones
      // Keep H strictly positive; otherwise a zero entry could never grow again.
      max.inPlace(H, eps)
      // NOTE(review): presumably forces collection of the per-iteration GPU temporaries
      // created above, so device memory is not exhausted — confirm against CuMatrix.
      System.gc()
    }
    H
  }

  /**
   * Small benchmark and smoke test for [[supervised]].
   *
   * It builds an all-ones basis `W` (1024 × 108) and all-ones coefficients (108 × 10000), so
   * `X = W * H` is exactly representable. It then recovers `H` with [[supervised]] and prints
   * the maximum absolute reconstruction error, followed by the elapsed wall-clock time.
   */
  def main(args: Array[String]): Unit = {
    import jcuda.jcublas._
    implicit val handle: cublasHandle = new cublasHandle
    JCublas2.cublasCreate(handle)
    try {
      val in = System.currentTimeMillis()
      val W = CuMatrix.ones[Float](1024, 108)
      val H = CuMatrix.ones[Float](108, 10000)
      val X = W * H
      val Hhat = supervised(W, X)
      val Xhat = W * Hhat
      // cuBLAS calls and kernel launches are asynchronous; wait for all queued GPU work so
      // the measured time covers the whole computation, not just the host-side launches.
      jcuda.runtime.JCuda.cudaDeviceSynchronize()
      val out = System.currentTimeMillis()
      println(max(abs(Xhat.toDense - X.toDense)))
      println(s"${(out - in)/1000.0} seconds")
    } finally {
      // Release the cuBLAS context created above, even if the computation throws.
      JCublas2.cublasDestroy(handle)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment