Skip to content

Instantly share code, notes, and snippets.

@makenowjust
Created March 21, 2025 16:35
Show Gist options
  • Save makenowjust/9c463af73bcf307e341cbe6acb3524e3 to your computer and use it in GitHub Desktop.
Save makenowjust/9c463af73bcf307e341cbe6acb3524e3 to your computer and use it in GitHub Desktop.
// An implementation of the RPNI (Regular Positive and Negative Inference) algorithm.
import scala.annotation.tailrec
import scala.collection.mutable
import scala.math.Ordering.Implicits.seqOrdering
import scala.util.boundary
/** A deterministic finite automaton with states of type `Q` over symbols of type `A`.
  *
  * @param initialState the state the automaton starts in
  * @param acceptStateSet the states in which the automaton accepts
  * @param transitionFunction successor state for each (state, symbol) pair; the map
  *   may be partial (only transitions that were produced are present)
  */
final case class Dfa[Q, A](initialState: Q, acceptStateSet: Set[Q], transitionFunction: Map[(Q, A), Q])
/** A sequence of symbols read from the initial state; identifies a prefix-tree node. */
type Prefix[A] = Seq[A]

/** A prefix naming a red state: one already promoted by the RPNI search. */
type RedPrefix[A] = Prefix[A]

/** A prefix naming a blue state: a candidate waiting to be merged or promoted. */
type BluePrefix[A] = Prefix[A]
/** A partition of prefix-tree states into blocks, kept as a union-find map.
  *
  * Every prefix that occurs in a sample maps to a [[Rpni.LinkNode]]: either a
  * `Node` holding the block stored under that key, or a `Link` to another
  * prefix whose block it has been merged into.
  *
  * Invariant maintained by [[tryMerge]]: a `Node` stored under key `k` holds a
  * block whose `shortPrefix` is `k`, so `block(p).shortPrefix` is the canonical
  * key of the block containing `p`.
  */
final case class Rpni[A](
    linkNodeMap: Map[Prefix[A], Rpni.LinkNode[A]]
):
  /** The block containing the empty prefix, i.e. the initial state. */
  def rootBlock: Rpni.Block[A] = block(Seq.empty)

  /** Resolves `prefix` to the block it currently belongs to by following links.
    *
    * @throws NoSuchElementException if `prefix` is not a key of `linkNodeMap`
    */
  @tailrec
  def block(prefix: Prefix[A]): Rpni.Block[A] =
    linkNodeMap(prefix) match
      case Rpni.LinkNode.Node(found)  => found
      case Rpni.LinkNode.Link(target) => block(target)

  /** Reads `seq` from the root block and returns the label of the block it ends in:
    * `Some(true)` / `Some(false)` if that block is labelled, `None` if it is not.
    *
    * @throws IllegalArgumentException if some symbol of `seq` has no transition
    */
  def run(seq: Seq[A]): Option[Boolean] =
    var current = rootBlock
    for char <- seq do
      val nextPrefix = current.children.getOrElse(
        char,
        throw new IllegalArgumentException("Incomplete transition")
      )
      current = block(nextPrefix)
    current.isAccept

  /** True iff no positive sample ends in a block labelled `Some(false)` and no
    * negative sample ends in a block labelled `Some(true)`. Unlabelled blocks are
    * compatible with both.
    */
  def compatible(
      positiveSampleSet: Set[Seq[A]],
      negativeSampleSet: Set[Seq[A]]
  ): Boolean =
    positiveSampleSet.forall(run(_).forall(_ == true)) &&
      negativeSampleSet.forall(run(_).forall(_ == false))

  /** Merges the block of `bluePrefix` into the block of `redPrefix`, then folds
    * their successors together recursively so the result stays deterministic.
    *
    * Returns `None` if any pair of blocks merged along the way carries
    * conflicting labels, otherwise the new partition. Merging two prefixes that
    * already share a block returns `this` unchanged.
    */
  def tryMerge(
      redPrefix: RedPrefix[A],
      bluePrefix: BluePrefix[A]
  ): Option[Rpni[A]] =
    val redBlock = block(redPrefix)
    val blueBlock = block(bluePrefix)
    // Work on canonical keys: the raw prefixes may be links, and writing a Node
    // under a link key would fork the class instead of updating its root.
    val redKey = redBlock.shortPrefix
    val blueKey = blueBlock.shortPrefix
    if redKey == blueKey then Some(this)
    else if !redBlock.isAccept.zip(blueBlock.isAccept).forall(_ == _) then None
    else
      val mergedBlock = Rpni.Block(
        redKey,
        redBlock.prefixSet ++ blueBlock.prefixSet,
        redBlock.isAccept.orElse(blueBlock.isAccept),
        // On a shared symbol red's successor wins; blue's is folded into it below.
        blueBlock.children ++ redBlock.children
      )
      // Commit the union before recursing so recursive merges observe it; this
      // matters when successors loop back into either of the two blocks.
      val united = copy(linkNodeMap =
        linkNodeMap
          .updated(redKey, Rpni.LinkNode.Node(mergedBlock))
          .updated(blueKey, Rpni.LinkNode.Link(redKey))
      )
      // Each recursive merge resolves blocks from the current state, never `this`.
      blueBlock.children.foldLeft(Option(united)) { (acc, entry) =>
        val (char, nextBluePrefix) = entry
        acc.flatMap(current =>
          redBlock.children
            .get(char)
            .fold(Option(current))(current.tryMerge(_, nextBluePrefix))
        )
      }

  /** Builds a DFA whose states are the canonical prefixes of the blocks reachable
    * from the root block (breadth-first). A state accepts iff its block is
    * labelled `Some(true)`; unlabelled blocks become rejecting states.
    */
  def toDfa: Dfa[Seq[A], A] =
    val initialState = rootBlock.shortPrefix
    val acceptStateSet = Set.newBuilder[Seq[A]]
    val transitionFunction = Map.newBuilder[(Seq[A], A), Seq[A]]
    val visited = mutable.Set(initialState)
    val queue = mutable.Queue(initialState)
    while queue.nonEmpty do
      val prefix = queue.dequeue()
      val current = block(prefix)
      if current.isAccept.contains(true) then acceptStateSet.addOne(prefix)
      for (char, nextPrefix) <- current.children do
        val nextShortPrefix = block(nextPrefix).shortPrefix
        transitionFunction.addOne((prefix, char) -> nextShortPrefix)
        // `add` returns true only for states not seen before.
        if visited.add(nextShortPrefix) then queue.enqueue(nextShortPrefix)
    Dfa(initialState, acceptStateSet.result(), transitionFunction.result())
object Rpni:
  /** Entry of the union-find map: either the block stored under this key, or a
    * link to another prefix whose block this prefix has been merged into.
    */
  enum LinkNode[A]:
    case Node(block: Block[A])
    case Link(prefix: Prefix[A])

  /** A block (equivalence class) of prefix-tree states.
    *
    * @param shortPrefix representative prefix of the block
    * @param prefixSet every prefix merged into this block
    * @param isAccept `Some(true)` / `Some(false)` if a member ends a positive /
    *   negative sample, `None` if the block is unlabelled
    * @param children successor prefix per symbol (resolve it with `Rpni.block`)
    */
  final case class Block[A](
      shortPrefix: Prefix[A],
      prefixSet: Set[Prefix[A]],
      isAccept: Option[Boolean],
      children: Map[A, Seq[A]]
  )

  /** Builds the prefix tree of the samples: one singleton block per prefix of
    * every sample, with each complete sample labelled accept/reject.
    *
    * @throws IllegalArgumentException if a sequence is both a positive and a
    *   negative sample
    */
  def build[A](
      positiveSampleSet: Set[Seq[A]],
      negativeSampleSet: Set[Seq[A]]
  ): Rpni[A] =
    val blockMap = mutable.Map.empty[Prefix[A], Block[A]]

    // Ensures a block for `prefix` exists, merges `isAccept` into its label and
    // returns the updated block.
    def insert(prefix: Prefix[A], isAccept: Option[Boolean]): Block[A] =
      val existing =
        blockMap.getOrElse(prefix, Block(prefix, Set(prefix), isAccept, Map.empty))
      if !existing.isAccept.zip(isAccept).forall(_ == _) then
        throw new IllegalArgumentException(
          "Inconsistent positive and negative samples"
        )
      val labelled = existing.copy(isAccept = existing.isAccept.orElse(isAccept))
      blockMap(prefix) = labelled
      labelled

    // Adds every prefix of `sample`, links each to its one-symbol extension and
    // labels the complete sample with `label`.
    def insertSample(sample: Seq[A], label: Boolean): Unit =
      for i <- 0 to sample.length do
        val prefix = sample.take(i)
        val node = insert(prefix, Option.when(i == sample.length)(label))
        if i < sample.length then
          blockMap(prefix) =
            node.copy(children = node.children.updated(sample(i), sample.take(i + 1)))

    positiveSampleSet.foreach(insertSample(_, label = true))
    negativeSampleSet.foreach(insertSample(_, label = false))
    Rpni(blockMap.iterator.map((prefix, block) => prefix -> LinkNode.Node(block)).toMap)

  /** Learns a DFA consistent with the samples using the blue-fringe RPNI strategy.
    *
    * Red states are kept; blue states are the non-red successors of red states.
    * The smallest blue state (length-then-lexicographic order) is merged into the
    * first red state for which the merge stays compatible with the samples, or
    * else promoted to red. A trace of each step is printed to stdout.
    *
    * @throws IllegalArgumentException if the sample sets are inconsistent
    */
  def run[A](positiveSampleSet: Set[Seq[A]], negativeSampleSet: Set[Seq[A]])(using
      Ordering[A]
  ): Dfa[Seq[A], A] =
    val lexOrdering = Ordering.by[Seq[A], Int](_.length).orElse(seqOrdering)
    var rpni = build(positiveSampleSet, negativeSampleSet)
    val redPrefixSet = mutable.SortedSet.empty[RedPrefix[A]](using lexOrdering)
    val bluePrefixQueue =
      mutable.PriorityQueue.empty[BluePrefix[A]](using lexOrdering.reverse)

    // Canonical prefix of the block `prefix` currently belongs to.
    def canonical(prefix: Prefix[A]): Prefix[A] = rpni.block(prefix).shortPrefix

    // True if `prefix` now lies in the same block as some red state.
    def isRed(prefix: Prefix[A]): Boolean =
      val key = canonical(prefix)
      redPrefixSet.exists(canonical(_) == key)

    // Queues every successor of a red state that is neither red nor already
    // queued. A merge may give any red state new successors, so all are scanned.
    def refreshBlue(): Unit =
      for
        redPrefix <- redPrefixSet.toSeq
        nextPrefix <- rpni.block(redPrefix).children.values
      do
        val candidate = canonical(nextPrefix)
        if !isRed(candidate) && !bluePrefixQueue.exists(_ == candidate) then
          bluePrefixQueue.enqueue(candidate)

    redPrefixSet.add(Seq.empty)
    refreshBlue()
    while bluePrefixQueue.nonEmpty do
      println(s"red: ${redPrefixSet}")
      println(s"blue: ${bluePrefixQueue}")
      val bluePrefix = canonical(bluePrefixQueue.dequeue())
      // A queued state may since have been absorbed into a red block by a fold.
      if !isRed(bluePrefix) then
        val merge = redPrefixSet.iterator
          .map(redPrefix => (redPrefix, rpni.tryMerge(redPrefix, bluePrefix)))
          .collectFirst {
            case (redPrefix, Some(newRpni))
                if newRpni.compatible(positiveSampleSet, negativeSampleSet) =>
              (redPrefix, newRpni)
          }
        merge match
          case Some((redPrefix, newRpni)) =>
            println(s"merge: $redPrefix, $bluePrefix")
            rpni = newRpni
          case None =>
            println(s"promote: $bluePrefix")
            redPrefixSet.add(bluePrefix)
        refreshBlue()
    rpni.toDfa
// Demo 1: binary samples; learn a DFA consistent with them and print it.
val dfa1 = Rpni.run(
  positiveSampleSet = Set(Seq(1, 1, 1), Seq(0, 0, 0), Seq(1, 1, 1, 0, 1), Seq(0, 1)),
  negativeSampleSet = Set(Seq(0), Seq(1), Seq(0, 0), Seq(1, 1))
)
println(dfa1)

// Demo 2: a smaller sample set with a single negative example.
val dfa2 = Rpni.run(
  positiveSampleSet = Set(Seq(0, 1, 1), Seq(1, 0, 1)),
  negativeSampleSet = Set(Seq(1))
)
println(dfa2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment