Created
March 21, 2025 16:35
-
-
Save makenowjust/9c463af73bcf307e341cbe6acb3524e3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// An implementation of the RPNI (Regular Positive and Negative Inference) algorithm. | |
import scala.annotation.tailrec | |
import scala.collection.mutable | |
import scala.math.Ordering.Implicits.seqOrdering | |
import scala.util.boundary | |
/** A deterministic finite automaton.
  *
  * @tparam Q the state type
  * @tparam A the input symbol type
  * @param initialState the start state
  * @param acceptStateSet the accepting states
  * @param transitionFunction the (possibly partial) transition map from a
  *   `(state, symbol)` pair to the next state
  */
final case class Dfa[Q, A](initialState: Q, acceptStateSet: Set[Q], transitionFunction: Map[(Q, A), Q])
/** A word prefix (sequence of input symbols); each one names a prefix-tree node. */
type Prefix[A] = Seq[A]

/** A prefix whose block has been fixed as a state ("red") by the RPNI loop. */
type RedPrefix[A] = Seq[A]

/** A prefix whose block is a merge candidate ("blue") in the RPNI loop. */
type BluePrefix[A] = Seq[A]
/** A partially merged prefix-tree acceptor: the hypothesis manipulated by RPNI.
  *
  * Every prefix of every sample is a key of `linkNodeMap`. A key either owns a
  * block (a state of the hypothesis, kept under the block's `shortPrefix`) or
  * links to another prefix it has been merged into; following links always
  * ends at the owning block.
  *
  * @param linkNodeMap maps each prefix to its block or to a link target
  */
final case class Rpni[A](
    linkNodeMap: Map[Prefix[A], Rpni.LinkNode[A]]
):

  /** The block containing the empty prefix, i.e. the initial state. */
  def rootBlock: Rpni.Block[A] = block(Seq.empty)

  /** Resolves `prefix` to the block it currently belongs to by following links.
    *
    * @throws NoSuchElementException if `prefix` is not a key of `linkNodeMap`
    */
  @tailrec
  def block(prefix: Prefix[A]): Rpni.Block[A] =
    linkNodeMap(prefix) match
      case Rpni.LinkNode.Node(found)  => found
      case Rpni.LinkNode.Link(target) => block(target)

  /** Runs `seq` through the hypothesis starting from the root block.
    *
    * @return the `isAccept` label of the block reached (`None` if unlabelled)
    * @throws IllegalArgumentException if some step has no outgoing transition
    */
  def run(seq: Seq[A]): Option[Boolean] =
    val reached = seq.foldLeft(rootBlock) { (current, char) =>
      val nextPrefix = current.children.getOrElse(
        char,
        throw new IllegalArgumentException("Incomplete transition")
      )
      block(nextPrefix)
    }
    reached.isAccept

  /** True iff no positive sample ends in a block labelled `Some(false)` and no
    * negative sample ends in a block labelled `Some(true)`.
    */
  def compatible(
      positiveSampleSet: Set[Seq[A]],
      negativeSampleSet: Set[Seq[A]]
  ): Boolean =
    positiveSampleSet.forall(run(_).forall(_ == true)) &&
      negativeSampleSet.forall(run(_).forall(_ == false))

  /** Merges the block of `bluePrefix` into the block of `redPrefix` and folds
    * their successors together so the hypothesis stays deterministic: where
    * both blocks have a transition on the same symbol the two targets are
    * merged recursively, otherwise the blue transition is copied onto the red
    * block.
    *
    * The merged block keeps the red block's `shortPrefix` and is stored under
    * that key; the blue block's representative key becomes a link to it.
    * Merging a block with itself is a no-op.
    *
    * @return the merged hypothesis, or `None` if two blocks being merged carry
    *   conflicting labels. Consistency with the full sample sets is not checked
    *   here; use `compatible` for that.
    */
  def tryMerge(
      redPrefix: RedPrefix[A],
      bluePrefix: BluePrefix[A]
  ): Option[Rpni[A]] =
    val redBlock = block(redPrefix)
    val blueBlock = block(bluePrefix)
    val redKey = redBlock.shortPrefix
    if redKey == blueBlock.shortPrefix then
      // Already the same block (e.g. reached again through a cycle): nothing to do.
      // Without this check the fold would recurse forever around the cycle.
      Some(this)
    else if !redBlock.isAccept.zip(blueBlock.isAccept).forall(_ == _) then None
    else
      val mergedBlock = Rpni.Block(
        redKey,
        redBlock.prefixSet ++ blueBlock.prefixSet,
        redBlock.isAccept.orElse(blueBlock.isAccept),
        redBlock.children
      )
      // Update the representative keys (not the raw argument prefixes, which
      // may themselves be links) so every other link keeps resolving correctly.
      val merged = copy(linkNodeMap =
        linkNodeMap
          .updated(redKey, Rpni.LinkNode.Node(mergedBlock))
          .updated(blueBlock.shortPrefix, Rpni.LinkNode.Link(redKey))
      )
      blueBlock.children.foldLeft(Option(merged)) {
        case (None, _) => None
        case (Some(current), (char, nextBluePrefix)) =>
          // Re-read the red block from the *current* hypothesis on every step:
          // earlier recursive merges (e.g. through self-loops) may have changed it.
          val currentRed = current.block(redKey)
          currentRed.children.get(char) match
            case Some(nextRedPrefix) =>
              current.tryMerge(nextRedPrefix, nextBluePrefix)
            case None =>
              val extended = currentRed.copy(
                children = currentRed.children.updated(char, nextBluePrefix)
              )
              Some(
                current.copy(linkNodeMap =
                  current.linkNodeMap.updated(
                    currentRed.shortPrefix,
                    Rpni.LinkNode.Node(extended)
                  )
                )
              )
      }

  /** Converts the hypothesis into a [[Dfa]] whose states are the `shortPrefix`es
    * of the blocks reachable from the root, explored breadth-first. Blocks
    * labelled `Some(true)` are accepting; all others are non-accepting.
    */
  def toDfa: Dfa[Seq[A], A] =
    val initialState = Seq.empty[A]
    val acceptStateSet = Set.newBuilder[Seq[A]]
    val transitionFunction = Map.newBuilder[(Seq[A], A), Seq[A]]
    val visited = mutable.Set(initialState)
    val queue = mutable.Queue(initialState)
    while queue.nonEmpty do
      val state = queue.dequeue()
      val current = block(state)
      if current.isAccept.contains(true) then acceptStateSet.addOne(state)
      for (char, nextPrefix) <- current.children do
        val nextState = block(nextPrefix).shortPrefix
        transitionFunction.addOne((state, char) -> nextState)
        // `add` returns true only the first time a state is seen.
        if visited.add(nextState) then queue.enqueue(nextState)
    Dfa(initialState, acceptStateSet.result(), transitionFunction.result())
/** Companion of [[Rpni]]: the hypothesis data types and the RPNI learner. */
object Rpni:

  /** An entry of `Rpni.linkNodeMap`: a prefix either owns a [[Block]] or links
    * to the prefix it has been merged into.
    */
  enum LinkNode[A]:
    case Node(block: Block[A])
    case Link(prefix: Prefix[A])

  /** One state of the hypothesis: a set of prefixes merged together.
    *
    * @param shortPrefix the representative prefix, used as the state's name by `toDfa`
    * @param prefixSet all prefixes merged into this block
    * @param isAccept `Some(true)` / `Some(false)` if a positive / negative sample
    *   ends in this block, `None` otherwise
    * @param children outgoing transitions; each target is a prefix that may
    *   since have been merged into another block
    */
  final case class Block[A](
      shortPrefix: Prefix[A],
      prefixSet: Set[Prefix[A]],
      isAccept: Option[Boolean],
      children: Map[A, Seq[A]]
  )

  /** Builds the prefix-tree acceptor of the samples: one unmerged block per
    * prefix of every sample, labelled where a positive or negative sample ends.
    *
    * @throws IllegalArgumentException if a word is both a positive and a
    *   negative sample
    */
  def build[A](
      positiveSampleSet: Set[Seq[A]],
      negativeSampleSet: Set[Seq[A]]
  ): Rpni[A] =
    val blockMap = mutable.Map.empty[Prefix[A], Block[A]]

    // Ensures a block exists for `prefix` and merges in its label, rejecting conflicts.
    def insert(prefix: Prefix[A], isAccept: Option[Boolean]): Unit =
      val block =
        blockMap.getOrElse(prefix, Block(prefix, Set(prefix), None, Map.empty))
      if !block.isAccept.zip(isAccept).forall(_ == _) then
        throw new IllegalArgumentException(
          "Inconsistent positive and negative samples"
        )
      blockMap(prefix) = block.copy(isAccept = block.isAccept.orElse(isAccept))

    // Adds the path of `sample` to the tree and labels its final node with `label`.
    def insertSample(sample: Seq[A], label: Boolean): Unit =
      for i <- 0 to sample.length do
        val prefix = sample.take(i)
        insert(prefix, Option.when(i == sample.length)(label))
        if i < sample.length then
          val block = blockMap(prefix)
          blockMap(prefix) = block.copy(
            children = block.children.updated(sample(i), sample.take(i + 1))
          )

    positiveSampleSet.foreach(insertSample(_, true))
    negativeSampleSet.foreach(insertSample(_, false))
    Rpni(
      blockMap.iterator
        .map((prefix, block) => prefix -> LinkNode.Node(block))
        .toMap
    )

  /** Learns a DFA consistent with the samples with the red-blue RPNI loop.
    *
    * Blue candidates (blocks one transition away from a red block) are taken in
    * shortlex order (length first, then lexicographically by `Ordering[A]`).
    * Each is merged into the first red block, also in shortlex order, that
    * yields a hypothesis compatible with all samples; if there is none it is
    * promoted to red.
    *
    * @param verbose when true, prints the red/blue sets and every merge or
    *   promotion to stdout; defaults to false
    * @throws IllegalArgumentException if a word is both a positive and a
    *   negative sample
    */
  def run[A](
      positiveSampleSet: Set[Seq[A]],
      negativeSampleSet: Set[Seq[A]],
      verbose: Boolean = false
  )(using Ordering[A]): Dfa[Seq[A], A] =
    val lexOrdering = Ordering.by[Seq[A], Int](_.length).orElse(seqOrdering)
    var rpni = build(positiveSampleSet, negativeSampleSet)
    val redPrefixSet = mutable.SortedSet.empty[RedPrefix[A]](using lexOrdering)
    val bluePrefixQueue =
      mutable.PriorityQueue.empty[BluePrefix[A]](using lexOrdering.reverse)

    def trace(message: => String): Unit = if verbose then println(message)

    // Enqueues every block one transition away from ANY red block that is not
    // red itself and not already queued. Folding can add transitions to red
    // blocks other than the merge target, so the whole frontier is rescanned.
    def enqueueFrontier(): Unit =
      for
        redPrefix <- redPrefixSet
        child <- rpni.block(redPrefix).children.values
      do
        val shortPrefix = rpni.block(child).shortPrefix
        if !redPrefixSet.contains(shortPrefix) && !bluePrefixQueue.exists(_ == shortPrefix) then
          bluePrefixQueue.enqueue(shortPrefix)

    redPrefixSet.add(Seq.empty)
    enqueueFrontier()
    while bluePrefixQueue.nonEmpty do
      trace(s"red: ${redPrefixSet}")
      trace(s"blue: ${bluePrefixQueue}")
      // Queue entries may be stale: resolve to the current representative and
      // skip blocks that have since been promoted or merged into a red block.
      val bluePrefix = rpni.block(bluePrefixQueue.dequeue()).shortPrefix
      if !redPrefixSet.contains(bluePrefix) then
        // Lazily try the red blocks in order and stop at the first compatible merge.
        val firstMerge = redPrefixSet.iterator
          .flatMap { redPrefix =>
            rpni
              .tryMerge(redPrefix, bluePrefix)
              .filter(_.compatible(positiveSampleSet, negativeSampleSet))
              .map(redPrefix -> _)
          }
          .nextOption()
        firstMerge match
          case Some((redPrefix, merged)) =>
            trace(s"merge: $redPrefix, $bluePrefix")
            rpni = merged
          case None =>
            trace(s"promote: $bluePrefix")
            redPrefixSet.add(bluePrefix)
        enqueueFrontier()
    rpni.toDfa
// Demo 1: four accepted and four rejected words over the alphabet {0, 1}.
val dfa1 = Rpni.run(
  positiveSampleSet = Set(Seq(1, 1, 1), Seq(0, 0, 0), Seq(1, 1, 1, 0, 1), Seq(0, 1)),
  negativeSampleSet = Set(Seq(0), Seq(1), Seq(0, 0), Seq(1, 1))
)
println(dfa1)

// Demo 2: two accepted words and a single rejected word.
val dfa2 = Rpni.run(
  positiveSampleSet = Set(Seq(0, 1, 1), Seq(1, 0, 1)),
  negativeSampleSet = Set(Seq(1))
)
println(dfa2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment