Created
October 4, 2013 19:28
-
-
Save gatorcse/6831359 to your computer and use it in GitHub Desktop.
A basic implementation of the Aho-Corasick multi-pattern string-matching algorithm in Scala. It is not thread-safe, because the tree uses mutable state.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.tlo.ahocorasick | |
import scala.collection.mutable | |
import scala.collection.immutable | |
import scala.annotation.tailrec | |
/**
 * A state of the Aho-Corasick trie.
 *
 * Every state exposes its outgoing goto transitions, keyed by the next input character.
 */
sealed trait Node { def children: mutable.Map[Char, LetterNode] }
/**
 * The root state of the trie: it represents the empty prefix and has no parent or failure link.
 *
 * Equality and hashing are by reference. The generated case-class versions would traverse the
 * whole mutable trie through `children` (and, via the nodes' parent/failure links, loop).
 * `toString` only lists the first letters of the root's branches for the same reason.
 *
 * @param children outgoing goto transitions from the root
 */
case class RootNode(children: mutable.Map[Char, LetterNode] = mutable.Map.empty) extends Node {
  override def equals(other: Any): Boolean = other match {
    case that: AnyRef => this eq that
    case _            => false
  }

  override def hashCode(): Int = System.identityHashCode(this)

  override def toString: String = s"RootNode(branches = ${children.keys.mkString(", ")})"
}
/**
 * A non-root trie state, reached by reading `ch` from `parent`.
 *
 * @param ch       character on the edge from `parent` to this node
 * @param children outgoing goto transitions
 * @param parent   the state this node hangs off; walked upwards to rebuild the matched word
 * @param failure  failure link; starts out as `parent` and is overwritten while the tree is built
 * @param isWord   whether the path from the root to this node spells a dictionary word
 */
case class LetterNode(
  ch: Char,
  children: mutable.Map[Char, LetterNode],
  parent: Node,
  var failure: Node,
  var isWord: Boolean
) extends Node {
  // Reference semantics: the generated case-class equals/hashCode/toString would follow the
  // cyclic parent/children/failure links and never terminate (StackOverflowError), and the
  // node is mutable, so structural hashing would be unsafe even without the cycles.
  override def equals(other: Any): Boolean = other match {
    case that: AnyRef => this eq that
    case _            => false
  }

  override def hashCode(): Int = System.identityHashCode(this)

  override def toString: String = s"LetterNode($ch, isWord = $isWord)"
}
/**
 * A basic Aho-Corasick multi-pattern matcher over a mutable trie of `Node`s.
 *
 * The trie is only mutated while `generateTreeFrom` builds it; `searchForNamesIn` only reads it.
 */
object AhoCorasick {

  /**
   * Generates the tree that future searches will be run against.
   *
   * @param dictionary words to use as the dictionary; the empty string is ignored
   * @return root node of the search tree, with every failure link resolved
   */
  def generateTreeFrom(dictionary: Set[String]): Node = {
    val root = RootNode()
    // Goto function: insert every word into the trie.
    dictionary.foreach(word => addWordToTree(root, word.toList))
    // Failure function: resolve links breadth-first, starting from the root's children.
    linkFailures(root.children.values.foldLeft(immutable.Queue.empty[LetterNode])(_ enqueue _))
    root
  }

  /**
   * Given a corpus of text, finds every dictionary word that occurs in it.
   *
   * Words that end inside a longer match (e.g. "he" inside "she") are reported as well, by
   * following failure links from every state the automaton passes through.
   *
   * @param corpus   corpus of text to match against
   * @param dictNode root node of the tree representing the dictionary
   * @return all matching strings (one instance of each)
   */
  def searchForNamesIn(corpus: String, dictNode: Node): Set[String] = {
    // Nodes are tracked by reference: hash-based Scala collections would call the nodes'
    // case-class hashCode, which can recurse through the cyclic parent/children/failure links.
    val expanded = java.util.Collections.newSetFromMap(
      new java.util.IdentityHashMap[LetterNode, java.lang.Boolean]())

    // Adds every word on `node`'s failure chain to `found`. A node already in `expanded` had
    // its whole chain collected earlier, so the walk stops there; this keeps the total work
    // proportional to the number of distinct states visited.
    @tailrec
    def collectOutputs(node: Node, found: Set[String]): Set[String] = node match {
      case _: RootNode => found
      case ln: LetterNode =>
        if (!expanded.add(ln)) found
        else collectOutputs(ln.failure, if (ln.isWord) found + nodeToString(ln) else found)
    }

    val (_, matches) = corpus.foldLeft((dictNode, Set.empty[String])) {
      case ((state, found), letter) =>
        val next = advance(state, letter)
        (next, collectOutputs(next, found))
    }
    matches
  }

  /** Walks (and extends) the trie along `word`, marking the node for its last letter as a word. */
  @tailrec
  private def addWordToTree(lastNode: Node, word: List[Char]): Unit = word match {
    case Nil =>
      lastNode match {
        case ln: LetterNode => ln.isWord = true
        case _: RootNode    => () // empty word: nothing to mark
      }
    case letter :: tail =>
      // The failure link starts as the parent; linkFailures overwrites it afterwards.
      val next = lastNode.children.getOrElseUpdate(
        letter,
        LetterNode(letter, mutable.Map.empty, lastNode, lastNode, isWord = false))
      addWordToTree(next, tail)
  }

  /**
   * Resolves failure links breadth-first. Every node's parent, and therefore every shallower
   * node its failure target can come from, is final before the node itself is processed.
   */
  @tailrec
  private def linkFailures(queue: immutable.Queue[LetterNode]): Unit =
    queue.dequeueOption match {
      case None => ()
      case Some((node, rest)) =>
        node.failure = failureTarget(node.parent, node.ch)
        linkFailures(node.children.values.foldLeft(rest)(_ enqueue _))
    }

  /**
   * Failure target for the child of `parent` labelled `ch`. This is the state for the longest
   * proper suffix of that child's path that is also in the trie, or the root if there is none.
   */
  @tailrec
  private def failureTarget(parent: Node, ch: Char): Node = parent match {
    case root: RootNode => root
    case ln: LetterNode =>
      ln.failure.children.get(ch) match {
        case Some(target) => target
        case None         => failureTarget(ln.failure, ch)
      }
  }

  /**
   * Goto transition with failure fallback. Follows failure links until `letter` can be consumed;
   * the root stays on itself for letters it has no branch for.
   */
  @tailrec
  private def advance(state: Node, letter: Char): Node =
    state.children.get(letter) match {
      case Some(child) => child
      case None =>
        state match {
          case _: RootNode    => state
          case ln: LetterNode => advance(ln.failure, letter)
        }
    }

  /** Rebuilds the word spelled by the path from the root down to `node`. */
  @tailrec
  private def nodeToString(node: Node, suffix: List[Char] = Nil): String = node match {
    case _: RootNode    => suffix.mkString
    case ln: LetterNode => nodeToString(ln.parent, ln.ch :: suffix)
  }
}
/** Small demo: builds a dictionary of names and reports which of them occur in a sentence. */
object ExampleUsage {
  import AhoCorasick._

  def main(args: Array[String]): Unit = {
    // Generate tree
    val names = Set("Russell Crowe", "Lance Armstrong", "Elon Musk")
    val dict = generateTreeFrom(names)
    // One branch per distinct first letter among the names.
    println(s"generated root node with ${dict.children.size} branches")

    // Search for names
    val text = "Hey Elon Musk, how are you?"
    val found = searchForNamesIn(text, dict)
    println(s"${found.size} names found")

    // output
    found.foreach(println)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment