Skip to content

Instantly share code, notes, and snippets.

@windoze
Created May 16, 2018 13:50
Show Gist options
  • Save windoze/ee3dae3015a4e40d4929b2ef5b05c427 to your computer and use it in GitHub Desktop.
Save windoze/ee3dae3015a4e40d4929b2ef5b05c427 to your computer and use it in GitHub Desktop.
package cn.azure.chatbot.classifier
import com.github.jfasttext.JFastText
import com.huaban.analysis.jieba.JiebaSegmenter
import com.huaban.analysis.jieba.JiebaSegmenter.SegMode
import com.intel.analytics.bigdl.nn.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import org.slf4j.LoggerFactory
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
/**
  * Text classifier that maps a free-text question to a ranked list of category names.
  *
  * Pipeline, as implemented below:
  *   1. segment the text into words with Jieba (`SegMode.SEARCH`),
  *   2. look up one fastText vector per word,
  *   3. truncate / zero-pad to `SEQUENCE_LEN` vectors and pack them into a
  *      `SEQUENCE_LEN x DIMENSIONS` tensor,
  *   4. run the BigDL model forward and rank the outputs,
  *   5. translate the ranked output indices into names from `CATEGORIES`.
  *
  * `initModel` must be called before `classifyString`; calling `classifyString`
  * first raises an `IllegalStateException`.
  */
object Classifier {
  private val log = LoggerFactory.getLogger("Classifier")

  /** Number of word vectors fed to the model; longer texts are truncated, shorter ones zero-padded. */
  private val SEQUENCE_LEN = 500

  /** Length of every padding vector and second dimension of the input tensor.
    * NOTE(review): presumably has to equal the fastText model's vector size — confirm. */
  private val DIMENSIONS = 300

  /** `SEQUENCE_LEN` all-zero vectors used to pad short inputs. */
  private val PADDING_VEC: List[Array[Float]] = List.fill(SEQUENCE_LEN)(Array.fill(DIMENSIONS)(0f))

  // TODO (carried over from the original): the category table is hard-coded.
  // Entries are looked up by the 0-based position of a score in the model output;
  // the empty name at index 0 is filtered out of the results by classifyString.
  private val CATEGORIES = Array(
    "", "CCS", "IoT", "Resource health", "SQL Database", "active-directory",
    "analysis-services", "api-management", "app-service", "application-gateway",
    "automation", "azure-portal", "azure-resource", "azure-resource-manager",
    "backup", "batch", "billing", "cdn", "cosmos-db", "downloads", "languages",
    "machine-learning", "multiple", "mysql", "open-resource", "scheduler",
    "security", "site-recovery", "storage", "virtual-machines", "virtual-network")

  // Default model locations used by main() when no arguments are supplied.
  // These are the paths that appeared in the previously commented-out loading code.
  private val DEFAULT_JFT_PATH = "cc.zh.300.bin"
  private val DEFAULT_BIGDL_PATH = "./faqmodel.bigdl"
  private val DEFAULT_BIGDL_WEIGHT_PATH = "faqmodel.bin"

  private val cutter = new JiebaSegmenter()

  // Both models are loaded by initModel; None means "not initialised yet".
  private var jft: Option[JFastText] = None
  private var model: Option[AbstractModule[Activity, Activity, Float]] = None

  /** Maximum number of category names returned by classifyString. */
  private var top: Int = 10

  /** Builds the error raised when the classifier is used before initModel. */
  private def notInitialized(what: String): Nothing =
    throw new IllegalStateException(s"Classifier $what is not loaded; call initModel(...) first")

  /** Splits a document into words using Jieba's search segmentation mode. */
  private def docToWords(s: String): List[String] = {
    // Explicit .asScala instead of the deprecated implicit JavaConversions.
    val ret = cutter.process(s, SegMode.SEARCH).asScala.toList.map(_.word)
    if (log.isDebugEnabled) log.debug(s"docToWords($s) => $ret")
    ret
  }

  /** Looks up the fastText vector of every word, preserving order. */
  private def wordsToVecs(words: List[String]): List[Array[Float]] = {
    val fastText = jft.getOrElse(notInitialized("fastText model"))
    val ret = words.map(w => fastText.getWordVector(w).asScala.map(_.toFloat).toArray[Float])
    if (log.isDebugEnabled) log.debug(s"wordsToVec('$words') => $ret")
    ret
  }

  /** Converts a document into exactly `SEQUENCE_LEN` vectors (truncated or zero-padded). */
  private def docToVecs(s: String): List[Array[Float]] = {
    // Only the first SEQUENCE_LEN words can survive the final take, so drop the
    // rest before paying for their vector lookups.
    val words = docToWords(s).take(SEQUENCE_LEN)
    (wordsToVecs(words) ++ PADDING_VEC).take(SEQUENCE_LEN)
  }

  /** Packs the vectors into a `SEQUENCE_LEN x DIMENSIONS` tensor via `Utils.concatArray`. */
  private def vecToTensor(vec: List[Array[Float]]): Tensor[Float] =
    Utils.concatArray(vec).resize(SEQUENCE_LEN, DIMENSIONS)

  /** Full text-to-tensor conversion for a single document. */
  private def docToTensor(s: String): Tensor[Float] = vecToTensor(docToVecs(s))

  /**
    * Maps a 0-based model output position to its category name.
    *
    * Positions outside `CATEGORIES` yield the empty string (which callers filter
    * out) and a warning, instead of aborting the whole classification.
    */
  private def indexToCategoryName(idx: Int): String =
    CATEGORIES.lift(idx).getOrElse {
      log.warn(s"No category name for model output index $idx")
      ""
    }

  /**
    * Loads the fastText and BigDL models and sets the result size.
    *
    * State is only replaced after both models loaded successfully, so a failed
    * call leaves any previously loaded models in place.
    *
    * @param jftPath         path of the fastText model file
    * @param bigdlPath       path of the BigDL module definition
    * @param bigdlWeightPath path of the BigDL weight file
    * @param top             maximum number of categories returned by `classifyString`
    */
  def initModel(jftPath: String, bigdlPath: String, bigdlWeightPath: String, top: Int): Unit = {
    val fastText = new JFastText()
    fastText.loadModel(jftPath)
    val loaded: AbstractModule[Activity, Activity, Float] = Module.loadModule(bigdlPath, bigdlWeightPath)
    jft = Some(fastText)
    model = Some(loaded)
    this.top = top
  }

  /**
    * Classifies a piece of text.
    *
    * @param s the text to classify
    * @return up to `top` non-empty category names, best match first
    * @throws IllegalStateException if `initModel` has not been called
    */
  def classifyString(s: String): java.util.List[String] = {
    val net = model.getOrElse(notInitialized("BigDL model"))
    // NOTE(review): reads the raw backing storage of the output tensor; assumes the
    // output is a single contiguous score vector starting at offset 0 — confirm.
    val scores = net.forward(docToTensor(s)).asInstanceOf[Tensor[Float]].storage().array()
    // Ranked by ascending absolute value, exactly as before.
    // NOTE(review): presumably the model emits log-probabilities, so the smallest
    // magnitude is the most likely class — confirm against the model's last layer.
    scores.zipWithIndex
      .sortBy(_._1.abs)
      .iterator // lazy: only map as many indices as needed to fill `top`
      .map { case (_, idx) => indexToCategoryName(idx) }
      .filter(_.nonEmpty)
      .take(top)
      .toList
      .asJava
  }

  /**
    * Smoke test: loads the models and prints the classification of a few sample strings.
    *
    * Optional arguments: `jftPath bigdlPath bigdlWeightPath top`. Missing arguments
    * fall back to the default paths above and the current `top` value; a non-numeric
    * `top` fails fast with a `NumberFormatException`.
    */
  def main(argv: Array[String]): Unit = {
    val jftPath = argv.lift(0).getOrElse(DEFAULT_JFT_PATH)
    val bigdlPath = argv.lift(1).getOrElse(DEFAULT_BIGDL_PATH)
    val bigdlWeightPath = argv.lift(2).getOrElse(DEFAULT_BIGDL_WEIGHT_PATH)
    val topN = argv.lift(3).map(_.toInt).getOrElse(top)
    // The models must be loaded before classifying; previously this step was missing.
    initModel(jftPath, bigdlPath, bigdlWeightPath, topN)

    val tests = Array("如何开发票", "azure vm", "密码重置")
    tests.foreach(
      s => println(s"classify($s) => ${classifyString(s)}")
    )
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment