Created
May 16, 2018 13:50
-
-
Save windoze/ee3dae3015a4e40d4929b2ef5b05c427 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cn.azure.chatbot.classifier | |
import com.github.jfasttext.JFastText | |
import com.huaban.analysis.jieba.JiebaSegmenter | |
import com.huaban.analysis.jieba.JiebaSegmenter.SegMode | |
import com.intel.analytics.bigdl.nn.Module | |
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} | |
import com.intel.analytics.bigdl.tensor.Tensor | |
import org.slf4j.LoggerFactory | |
import scala.collection.JavaConversions._ | |
import scala.collection.JavaConverters._ | |
/**
 * Text classifier that segments a document with Jieba, embeds every token with
 * a fastText model and feeds the resulting matrix to a BigDL network.
 *
 * The object is stateful: [[initModel]] must be called once before
 * [[classifyString]] is used.
 */
object Classifier {
  private val log = LoggerFactory.getLogger("Classifier")

  /** Number of word vectors per document; documents are padded or truncated to this length. */
  private val SEQUENCE_LEN = 500

  /** Width of one word vector. Presumably matches the fastText model dimension -- TODO confirm. */
  private val DIMENSIONS = 300

  /** All-zero vectors appended after the real word vectors so every document has SEQUENCE_LEN rows. */
  private val PADDING_VEC: List[Array[Float]] = List.fill(SEQUENCE_LEN)(Array.fill(DIMENSIONS)(0f))

  // Category names addressed by the position of a score in the model output.
  // The empty name at index 0 is dropped from results by classifyString.
  // NOTE(review): the list is hard-coded; confirm that it matches the label order used in training.
  private val CATEGORIES = Array("", "CCS", "IoT", "Resource health", "SQL Database", "active-directory", "analysis-services", "api-management", "app-service", "application-gateway", "automation", "azure-portal", "azure-resource", "azure-resource-manager", "backup", "batch", "billing", "cdn", "cosmos-db", "downloads", "languages", "machine-learning", "multiple", "mysql", "open-resource", "scheduler", "security", "site-recovery", "storage", "virtual-machines", "virtual-network")

  // Model locations used by main when no command-line arguments are supplied.
  private val DEFAULT_JFT_PATH = "cc.zh.300.bin"
  private val DEFAULT_BIGDL_PATH = "./faqmodel.bigdl"
  private val DEFAULT_BIGDL_WEIGHT_PATH = "faqmodel.bin"

  private val cutter = new JiebaSegmenter()

  // Both models are loaded lazily through initModel, hence the mutable, initially-null fields.
  private var jft: JFastText = _
  private var model: AbstractModule[Activity, Activity, Float] = _

  /** Maximum number of category names returned by classifyString. */
  private var top: Int = 10

  /** Splits a document into words using Jieba in SEARCH mode. */
  private def docToWords(s: String): List[String] = {
    // Explicit .asScala instead of the deprecated implicit JavaConversions.
    val ret = cutter.process(s, SegMode.SEARCH).asScala.toList.map(_.word)
    if (log.isDebugEnabled) {
      log.debug(s"docToWords($s) => $ret")
    }
    ret
  }

  /** Looks up the fastText vector of every word and converts it to a primitive float array. */
  private def wordsToVecs(words: List[String]) = {
    val ret = words.map(jft.getWordVector(_).asScala.map(_.toFloat).toArray[Float])
    // Guarded so the (potentially large) message is only built when it will be emitted.
    if (log.isDebugEnabled) {
      log.debug(s"wordsToVec('$words') => $ret")
    }
    ret
  }

  /** Embeds a document and pads/truncates the result to exactly SEQUENCE_LEN vectors. */
  private def docToVecs(s: String): List[Array[Float]] =
    (wordsToVecs(docToWords(s)) ++ PADDING_VEC).take(SEQUENCE_LEN)

  /** Concatenates the vectors via Utils.concatArray and resizes to SEQUENCE_LEN x DIMENSIONS. */
  private def vecToTensor(vec: List[Array[Float]]): Tensor[Float] =
    Utils.concatArray(vec).resize(SEQUENCE_LEN, DIMENSIONS)

  /** Converts a raw document into the tensor that is passed to the model. */
  private def docToTensor(s: String): Tensor[Float] = vecToTensor(docToVecs(s))

  /**
   * Maps an output position to its category name.
   *
   * @param idx position of a score in the model output
   * @return the category name, or "" when no name is known for idx
   */
  private def indexToCategoryName(idx: Int): String =
    if (idx >= 0 && idx < CATEGORIES.length) {
      CATEGORIES(idx)
    } else {
      // The model produced more outputs than there are known names; skip instead of crashing.
      log.warn(s"No category name for output index $idx")
      ""
    }

  /**
   * Loads the fastText and BigDL models and sets the result size.
   *
   * @param jftPath         path of the fastText model file
   * @param bigdlPath       path of the BigDL module definition
   * @param bigdlWeightPath path of the BigDL weight file
   * @param top             maximum number of categories returned by classifyString
   */
  def initModel(jftPath: String, bigdlPath: String, bigdlWeightPath: String, top: Int): Unit = {
    // Load into locals first so a failure cannot leave the object half-initialised.
    val fastText = new JFastText()
    fastText.loadModel(jftPath)
    val network: AbstractModule[Activity, Activity, Float] = Module.loadModule(bigdlPath, bigdlWeightPath)
    jft = fastText
    model = network
    this.top = top
  }

  /**
   * Classifies a document.
   *
   * Scores are ordered ascending by absolute value before the first `top`
   * non-empty category names are taken.
   * NOTE(review): this ordering is right if the network emits log-probabilities
   * (smallest magnitude = most likely) -- confirm against the trained model.
   *
   * @param s the document text
   * @return up to `top` category names in ranked order
   * @throws IllegalStateException if initModel has not been called
   */
  def classifyString(s: String): java.util.List[String] = {
    if (model == null || jft == null) {
      throw new IllegalStateException("Classifier is not initialised; call initModel first")
    }
    val scores = model.forward(docToTensor(s))
      .asInstanceOf[Tensor[Float]].storage().array()
    scores.zipWithIndex
      .sortBy { case (score, _) => score.abs }
      .map { case (_, idx) => idx }
      .toList
      .map(indexToCategoryName)
      .filter(_.nonEmpty)
      .take(top)
      .asJava
  }

  /**
   * Smoke test: loads the models and classifies a few sample strings.
   *
   * @param argv optional model locations: fastText model, BigDL module, BigDL weights;
   *             missing entries fall back to the default file names
   */
  def main(argv: Array[String]): Unit = {
    def argOrElse(i: Int, default: String): String = if (argv.length > i) argv(i) else default

    // The models must be loaded before anything can be classified.
    initModel(
      argOrElse(0, DEFAULT_JFT_PATH),
      argOrElse(1, DEFAULT_BIGDL_PATH),
      argOrElse(2, DEFAULT_BIGDL_WEIGHT_PATH),
      top
    )

    val tests = Array("如何开发票", "azure vm", "密码重置")
    tests.foreach(
      s => println(s"classify($s) => ${classifyString(s)}")
    )
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment