Skip to content

Instantly share code, notes, and snippets.

@makenowjust
Created February 20, 2025 05:51
Show Gist options
  • Save makenowjust/53ed1b8e066952df9d4578d18d20097e to your computer and use it in GitHub Desktop.
Save makenowjust/53ed1b8e066952df9d4578d18d20097e to your computer and use it in GitHub Desktop.
// This is an implementation of the MAT* algorithm in Scala 3.
//
// The MAT* algorithm is a learning algorithm for symbolic finite automata, proposed by
// George Argyros and Loris D'Antoni, "The Learnability of Symbolic Automata"
// https://doi.org/10.1007/978-3-319-96145-3_23.
import scala.annotation.tailrec
import scala.util.hashing.MurmurHash3
/** An effective Boolean algebra over the domain `D`.
  *
  * @tparam D
  *   the type of the domain elements
  * @tparam P
  *   the type of the predicates over `D`
  */
trait BoolAlg[D, P]:

  /** The predicate that is always true. */
  def `true`: P

  /** The predicate that is always false. */
  def `false`: P

  /** The conjunction `p ∧ q`. */
  def and(p: P, q: P): P

  /** The disjunction `p ∨ q`. */
  def or(p: P, q: P): P

  /** The negation `¬p`. */
  def not(p: P): P

  /** Tests whether `d` belongs to the denotation of `p`. */
  def contains(p: P, d: D): Boolean

  /** Returns some element satisfying `p`, if one exists. */
  def witness(p: P): Option[D]
/** A closed interval `[left, right]` of integers. */
final class Interval(val left: Int, val right: Int):

  /** Tests whether `n` lies inside the interval; both bounds are inclusive. */
  def contains(n: Int): Boolean = n >= left && n <= right

  /** Tests whether this interval and `that` share at least one integer. */
  def overlaps(that: Interval): Boolean =
    math.max(left, that.left) <= math.min(right, that.right)

  /** Tests whether this interval and `that` are adjacent, i.e. one starts right after the other ends. */
  def contiguous(that: Interval): Boolean =
    val lo = math.max(left, that.left)
    val hi = math.min(right, that.right)
    lo == hi + 1

  /** Returns the union of this interval and `that` when it is itself a single interval, i.e. when the two overlap or
    * are contiguous; otherwise returns `None`.
    */
  def unionOf(that: Interval): Option[Interval] =
    if overlaps(that) || contiguous(that) then
      Some(Interval(math.min(left, that.left), math.max(right, that.right)))
    else None

  override def equals(that: Any): Boolean = that match
    case other: Interval => other.left == left && other.right == right
    case _               => false

  override def hashCode(): Int =
    val seed = "Interval".##
    val mixed = MurmurHash3.mix(MurmurHash3.mix(seed, left.##), right.##)
    MurmurHash3.finalizeHash(mixed, 2)

  override def toString(): String = s"Interval($left, $right)"

object Interval:

  /** Builds the interval `[left, right]`.
    *
    * Fails the `require` check (an `IllegalArgumentException`) when `left > right`.
    */
  def apply(left: Int, right: Int): Interval =
    require(left <= right, s"Invalid interval: $left > $right")
    new Interval(left, right)
/** A set of integers stored as a sequence of intervals.
  *
  * `contains` and `complement` read `intervals` as an ascending sequence of separated intervals, which is the shape
  * `IntervalSet.from` produces.
  */
final class IntervalSet(val intervals: IndexedSeq[Interval]):

  /** Tests whether the set has no interval at all. */
  def isEmpty: Boolean = intervals.isEmpty

  /** Tests whether the integer `n` belongs to the set. */
  def contains(n: Int): Boolean =
    // Search on right endpoints: the insertion point is the first interval whose right endpoint is >= n,
    // which is the only interval that could contain n.
    val byRight: Ordering[Interval] = Ordering.by(_.right)
    val at = intervals.search(Interval(n, n))(byRight).insertionPoint
    at < intervals.length && intervals(at).contains(n)

  /** Returns the union of this set and `that`. */
  infix def union(that: IntervalSet): IntervalSet =
    IntervalSet.from(intervals.iterator ++ that.intervals.iterator)

  /** Returns the intersection of this set and `that`, computed through De Morgan's law. */
  infix def intersect(that: IntervalSet): IntervalSet =
    val outsideEither = complement union that.complement
    outsideEither.complement

  /** Returns the complement of this set with respect to the whole `Int` range. */
  def complement: IntervalSet =
    if intervals.isEmpty then IntervalSet.universal
    else
      val first = intervals.head
      val last = intervals.last
      // Everything below the first interval (nothing when it already starts at Int.MinValue).
      val below = Option.when(first.left != Int.MinValue)(Interval(Int.MinValue, first.left - 1))
      // Everything above the last interval (nothing when it already ends at Int.MaxValue).
      val above = Option.when(last.right != Int.MaxValue)(Interval(last.right + 1, Int.MaxValue))
      // The gaps between each pair of neighbouring intervals.
      val gaps = intervals.iterator
        .zip(intervals.iterator.drop(1))
        .map((lower, upper) => Interval(lower.right + 1, upper.left - 1))
      new IntervalSet((below.iterator ++ gaps ++ above.iterator).toIndexedSeq)

  override def equals(that: Any): Boolean = that match
    case other: IntervalSet => other.intervals == intervals
    case _                  => false

  override def hashCode(): Int =
    MurmurHash3.mixLast("IntervalSet".##, intervals.##)

  override def toString(): String = intervals.mkString("IntervalSet(", ", ", ")")

object IntervalSet:

  /** The empty interval set. */
  val empty: IntervalSet = new IntervalSet(IndexedSeq.empty)

  /** The interval set covering every `Int`. */
  val universal: IntervalSet = new IntervalSet(IndexedSeq(Interval(Int.MinValue, Int.MaxValue)))

  /** Builds an interval set from the given intervals. */
  def apply(intervals: Interval*): IntervalSet = from(intervals)

  /** Builds an interval set from the given intervals, sorting them and merging every run of overlapping or contiguous
    * intervals into one.
    */
  def from(intervals: IterableOnce[Interval]): IntervalSet =
    val ordered = intervals.iterator.toSeq.sortBy(i => (i.left, i.right))
    ordered.headOption match
      case None => empty
      case Some(first) =>
        // `closed` holds the finished intervals, `pending` the one still being grown.
        val start = (Vector.empty[Interval], first)
        val (closed, pending) = ordered.tail.foldLeft(start): (acc, next) =>
          val (finished, current) = acc
          current.unionOf(next) match
            case Some(grown) => (finished, grown)
            case None        => (finished :+ current, next)
        new IntervalSet(closed :+ pending)

  /** The Boolean algebra of interval sets over `Int`. */
  given boolAlg: BoolAlg[Int, IntervalSet] with
    def `true`: IntervalSet = IntervalSet.universal
    def `false`: IntervalSet = IntervalSet.empty
    def and(p: IntervalSet, q: IntervalSet): IntervalSet = p intersect q
    def or(p: IntervalSet, q: IntervalSet): IntervalSet = p union q
    def not(p: IntervalSet): IntervalSet = p.complement
    def contains(p: IntervalSet, d: Int): Boolean = p.contains(d)

    /** The left endpoint of the first interval, or `None` for the empty set. */
    def witness(p: IntervalSet): Option[Int] =
      p.intervals.headOption.map(_.left)
/** A membership query over inputs of type `A`. */
trait Membership[A]:

  /** Returns `true` iff the input `a` is a member of the target language. */
  def apply(a: A): Boolean
/** `Learner` is the interface of a learning algorithm driven by membership queries and counterexamples.
  *
  * @tparam L
  *   the type of the learner instance (its state)
  * @tparam A
  *   the type of the input data
  * @tparam H
  *   the type of the hypothesis model
  */
trait Learner[L, A, H]:

  /** Returns a new learner instance. */
  def create(using Membership[A]): L

  /** Returns a learner updated with the given counterexample `cex`. */
  def update(learner: L, cex: A)(using Membership[A]): L

  /** Returns the hypothesis model conjectured by the learner. */
  def conject(learner: L)(using Membership[A]): H

object Learner:

  /** Learns a hypothesis model from a membership query and an equivalence query.
    *
    * A fresh learner is created, then its hypothesis is submitted to `eq` repeatedly; each returned counterexample is
    * fed back through `update`.
    *
    * @param mq
    *   the membership query passed to every learner operation
    * @param eq
    *   the equivalence query; returns `None` to accept a hypothesis or `Some(cex)` with a counterexample
    * @return
    *   the first hypothesis that `eq` accepts
    */
  def learn[L, A, H](mq: Membership[A], eq: (H) => Option[A])(using L: Learner[L, A, H]): H =
    given Membership[A] = mq

    @tailrec
    def refine(learner: L): H =
      val hypothesis = L.conject(learner)
      eq(hypothesis) match
        // Return the hypothesis that was actually accepted instead of conjecturing it a second time.
        case None      => hypothesis
        case Some(cex) => refine(L.update(learner, cex))

    refine(L.create)
/** `IntervalSetLearner` is a learner for the `IntervalSet` Boolean algebra.
  *
  * @param posExampleSet
  *   the counterexamples the membership query accepted
  * @param negExampleSet
  *   the counterexamples the membership query rejected
  */
final case class IntervalSetLearner(posExampleSet: Set[Int], negExampleSet: Set[Int]):

  /** Returns a new learner that records `cex` as a positive or negative example, as decided by `mq`. */
  def update(cex: Int)(using mq: Membership[Int]): IntervalSetLearner =
    if mq(cex) then copy(posExampleSet = posExampleSet + cex)
    else copy(negExampleSet = negExampleSet + cex)

  /** Returns the hypothesis model conjectured by the learner.
    *
    *   - no positive example: the empty set;
    *   - otherwise, no negative example: the universal set;
    *   - fewer positive than negative examples: exactly the positive examples;
    *   - otherwise: everything except the negative examples.
    */
  def conject(): IntervalSet =
    // Build the set of single points in one normalising pass instead of re-normalising after every `union`.
    def points(examples: Set[Int]): IntervalSet =
      IntervalSet.from(examples.iterator.map(n => Interval(n, n)))

    if posExampleSet.isEmpty then IntervalSet.empty
    else if negExampleSet.isEmpty then IntervalSet.universal
    else if posExampleSet.size < negExampleSet.size then points(posExampleSet)
    else points(negExampleSet).complement

object IntervalSetLearner:

  /** The learner that has seen no example yet. */
  val empty: IntervalSetLearner = IntervalSetLearner(Set.empty, Set.empty)

  /** `Learner` instance that delegates to the methods of `IntervalSetLearner`. */
  given learner: Learner[IntervalSetLearner, Int, IntervalSet] with
    def create(using Membership[Int]): IntervalSetLearner = IntervalSetLearner.empty

    def update(learner: IntervalSetLearner, cex: Int)(using mq: Membership[Int]): IntervalSetLearner =
      learner.update(cex)

    def conject(learner: IntervalSetLearner)(using Membership[Int]): IntervalSet =
      learner.conject()
/** `Sfa` represents a symbolic finite automaton.
  *
  * In this implementation, SFAs are assumed to be deterministic and finite.
  *
  * @param initialState
  *   the state a run starts from
  * @param acceptStateSet
  *   the accepting states
  * @param transitionFunction
  *   for each source state, the outgoing edges as a map from guard predicate to target state
  * @tparam S
  *   the type of states
  * @tparam P
  *   the type of guard predicates
  */
final case class Sfa[S, P](
    initialState: S,
    acceptStateSet: Set[S],
    transitionFunction: Map[S, Map[P, S]]
):

  /** Computes the next state from the given state and input data.
    *
    * Returns `None` when `state` has no entry in `transitionFunction` or when none of its guards contains `char`.
    */
  def transition[A](state: S, char: A)(using P: BoolAlg[A, P]): Option[S] =
    // Look the state up safely: a missing state is reported as `None`, in line with the return type.
    transitionFunction.get(state).flatMap: edgeMap =>
      edgeMap.find((guard, _) => P.contains(guard, char)).map(_._2)

  /** Computes the state reached from `state` after reading `word`, or `None` as soon as one step has no transition. */
  def transitions[A](state: S, word: Seq[A])(using P: BoolAlg[A, P]): Option[S] =
    word.foldLeft(Option(state))((current, char) => current.flatMap(transition(_, char)))
/** `Prefix` is a prefix of a word: a sequence of input symbols of type `A`.
  *
  * It is a plain alias of `Seq[A]`, used in this file to label classification-tree leaves.
  */
type Prefix[A] = Seq[A]
/** `Suffix` is a suffix of a word: a sequence of input symbols of type `A`.
  *
  * It is a plain alias of `Seq[A]`, used in this file to label classification-tree inner nodes.
  */
type Suffix[A] = Seq[A]
/** `CTree` is a classification tree.
  *
  * A `Leaf` carries a prefix. A `Node` carries a distinguishing suffix and two subtrees: a word `w` is routed to
  * `trueBranch` when the membership query accepts `w ++ suffix`, and to `falseBranch` otherwise.
  */
enum CTree[A]:
  case Leaf(prefix: Prefix[A])
  case Node(suffix: Suffix[A], trueBranch: CTree[A], falseBranch: CTree[A])

  /** Returns the prefixes of all leaves of the tree. */
  def leafSet: Set[Prefix[A]] = this match
    case Leaf(label)             => Set(label)
    case Node(_, onTrue, onFalse) => onTrue.leafSet ++ onFalse.leafSet

  /** Returns the prefix of the leaf that `word` is classified into. */
  @tailrec
  final def sift(word: Prefix[A])(using mq: Membership[Seq[A]]): Seq[A] = this match
    case Leaf(label) => label
    case Node(suffix, onTrue, onFalse) =>
      val next = if mq(word ++ suffix) then onTrue else onFalse
      next.sift(word)

  /** Returns a new tree in which the leaf `oldLeaf` is replaced by a node on `newSuffix` whose two children are
    * `oldLeaf` and `newLeaf`.
    *
    * The leaf is located by sifting `oldLeaf` itself; an assertion fails if the leaf reached is not `oldLeaf`.
    * `oldLeaf` goes under the true branch of the new node iff the membership query accepts `oldLeaf ++ newSuffix`.
    */
  final def split(oldLeaf: Prefix[A], newLeaf: Prefix[A], newSuffix: Suffix[A])(using
      mq: Membership[Seq[A]]
  ): CTree[A] =
    this match
      case Leaf(label) =>
        assert(label == oldLeaf, s"Invalid prefix: $label != $oldLeaf")
        if mq(oldLeaf ++ newSuffix) then Node(newSuffix, Leaf(oldLeaf), Leaf(newLeaf))
        else Node(newSuffix, Leaf(newLeaf), Leaf(oldLeaf))
      case Node(suffix, onTrue, onFalse) =>
        // Descend only into the branch that `oldLeaf` is classified into.
        val goesTrue = mq(oldLeaf ++ suffix)
        if goesTrue then Node(suffix, onTrue.split(oldLeaf, newLeaf, newSuffix), onFalse)
        else Node(suffix, onTrue, onFalse.split(oldLeaf, newLeaf, newSuffix))
/** `SfaLearner` is a learner for symbolic finite automata.
  *
  * @param tree
  *   the classification tree; its leaves are the states of the conjectured SFA
  * @param acceptMap
  *   for each leaf, the result of the membership query on that leaf
  * @param guardLearnerMap
  *   for each `(source leaf, target leaf)` pair, the learner of the guard of that transition
  * @tparam L
  *   the type of the guard learner instances
  * @tparam A
  *   the type of the input symbols
  */
final case class SfaLearner[L, A](
    tree: CTree[A],
    acceptMap: Map[Prefix[A], Boolean],
    guardLearnerMap: Map[(Prefix[A], Prefix[A]), L]
):

  /** Returns a membership query for the learner of the given transition.
    *
    * A symbol `a` is a member iff the current `tree` sifts `leaf1 ++ Seq(a)` to `leaf2`.
    */
  private def membership(leaf1: Prefix[A], leaf2: Prefix[A])(using Membership[Seq[A]]): Membership[A] =
    membershipIn(tree, leaf1, leaf2)

  /** Same as `membership`, but classifies words with the explicitly given tree `classifier`. */
  private def membershipIn(classifier: CTree[A], leaf1: Prefix[A], leaf2: Prefix[A])(using
      Membership[Seq[A]]
  ): Membership[A] =
    new Membership[A]:
      def apply(a: A): Boolean = classifier.sift(leaf1 ++ Seq(a)) == leaf2

  /** Splits the classification tree and updates the learner.
    *
    * Fresh guard learners are created for every transition from or to `newLeaf`, and for every transition from an
    * existing leaf into `oldLeaf`.
    */
  private def split[P](oldLeaf: Prefix[A], newLeaf: Prefix[A], newSuffix: Suffix[A])(using
      mq: Membership[Seq[A]],
      L: Learner[L, A, P]
  ): SfaLearner[L, A] =
    println(s"split($oldLeaf, $newLeaf, $newSuffix)")
    val newTree = tree.split(oldLeaf, newLeaf, newSuffix)
    val newAcceptMap = acceptMap + (newLeaf -> mq(newLeaf))
    val newLeafPairs = newTree.leafSet.iterator.flatMap(leaf => Iterator((newLeaf, leaf), (leaf, newLeaf)))
    val oldLeafPairs = tree.leafSet.iterator.map(leaf => (leaf, oldLeaf))
    // The fresh learners belong to the new tree, so their membership query must classify with `newTree`;
    // the old tree has no leaf `newLeaf` and could never sift a word to it.
    val freshLearners = (newLeafPairs ++ oldLeafPairs).map((leaf1, leaf2) =>
      (leaf1, leaf2) -> L.create(using membershipIn(newTree, leaf1, leaf2))
    )
    SfaLearner(newTree, newAcceptMap, guardLearnerMap ++ freshLearners)

  /** Make the guards complete for the given source leaf.
    *
    * While some symbol is covered by none of the guards leaving `srcLeaf`, that symbol is sifted to its target leaf
    * and given to the guard learner of that transition.
    *
    * @return
    *   whether any guard learner was updated, and the resulting learner
    */
  private def makeGuardsComplete[P](
      srcLeaf: Prefix[A]
  )(using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): (Boolean, SfaLearner[L, A]) =
    println(s"makeGuardsComplete($srcLeaf)")
    val leafSet = tree.leafSet
    var newGuardLearnerMap = guardLearnerMap
    var updated = false
    var continue = true
    while continue do
      val guards = leafSet.map: leaf =>
        L.conject(newGuardLearnerMap((srcLeaf, leaf)))(using membership(srcLeaf, leaf))
      // The symbols matched by no outgoing guard.
      val completePred = P.not(guards.foldLeft(P.`false`)(P.or(_, _)))
      P.witness(completePred) match
        case None => continue = false
        case Some(cex) =>
          val tgtLeaf = tree.sift(srcLeaf ++ Seq(cex))
          val learner = newGuardLearnerMap((srcLeaf, tgtLeaf))
          val newLearner = L.update(learner, cex)(using membership(srcLeaf, tgtLeaf))
          newGuardLearnerMap = newGuardLearnerMap.updated((srcLeaf, tgtLeaf), newLearner)
          updated = true
    (updated, copy(guardLearnerMap = newGuardLearnerMap))

  /** Make the guards deterministic for the given source leaf.
    *
    * For every pair of distinct target leaves whose guards share a symbol, each of the two guard learners that
    * disagrees with its membership query on that symbol is updated with it. The pass is repeated until a full pass
    * updates nothing.
    *
    * @return
    *   whether any guard learner was updated, and the resulting learner
    */
  private def makeGuardsDeterministic[P](
      srcLeaf: Prefix[A]
  )(using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): (Boolean, SfaLearner[L, A]) =
    println(s"makeGuardsDeterministic($srcLeaf)")
    val leafSet = tree.leafSet
    var newGuardLearnerMap = guardLearnerMap
    var updated = false
    var continue = true
    while continue do
      var updatedLocal = false
      for leaf1 <- leafSet; leaf2 <- leafSet; if leaf1 != leaf2 do
        val guard1 = L.conject(newGuardLearnerMap((srcLeaf, leaf1)))(using membership(srcLeaf, leaf1))
        val guard2 = L.conject(newGuardLearnerMap((srcLeaf, leaf2)))(using membership(srcLeaf, leaf2))
        // The symbols matched by both guards.
        val deterministicPred = P.and(guard1, guard2)
        P.witness(deterministicPred) match
          case None => ()
          case Some(cex) =>
            for (leaf, guard) <- Seq((leaf1, guard1), (leaf2, guard2)) do
              if membership(srcLeaf, leaf)(cex) != P.contains(guard, cex) then
                val learner = newGuardLearnerMap((srcLeaf, leaf))
                val newLearner = L.update(learner, cex)(using membership(srcLeaf, leaf))
                newGuardLearnerMap = newGuardLearnerMap.updated((srcLeaf, leaf), newLearner)
                updatedLocal = true
                updated = true
      if !updatedLocal then continue = false
    (updated, copy(guardLearnerMap = newGuardLearnerMap))

  /** Iterates `makeGuardsComplete` and `makeGuardsDeterministic` until one round of both updates nothing. */
  private def makeGuardsCompleteAndDeterministic[P](
      srcLeaf: Prefix[A]
  )(using Membership[Seq[A]], Learner[L, A, P], BoolAlg[A, P]): SfaLearner[L, A] =
    println(s"makeGuardsCompleteAndDeterministic($srcLeaf)")
    var learner = this
    var stable = false
    while !stable do
      val (completed, afterComplete) = learner.makeGuardsComplete[P](srcLeaf)
      val (determinized, afterDeterministic) = afterComplete.makeGuardsDeterministic[P](srcLeaf)
      learner = afterDeterministic
      // Either pass may invalidate the other one, so stop only when neither changed anything.
      stable = !completed && !determinized
    learner

  /** Returns the index of the breakpoint in the counterexample.
    *
    * For an index `i`, the probe word is the hypothesis state reached on `cex.slice(0, i)` followed by
    * `cex.slice(i, cex.length)`. The binary search keeps `low` on an index whose probe agrees with `mq(cex)`, lowers
    * `high` whenever a probe disagrees, and returns `low` once `high - low <= 1`.
    */
  def analyzeCex[P](hypothesis: Sfa[Prefix[A], P], cex: Seq[A])(using mq: Membership[Seq[A]], P: BoolAlg[A, P]): Int =
    val expected = mq(cex)
    var low = 0
    var high = cex.length
    while high - low > 1 do
      val mid = (high - low) / 2 + low
      // `.get` fails if the hypothesis has no transition for some symbol of the prefix.
      val state = hypothesis.transitions(hypothesis.initialState, cex.slice(0, mid)).get
      val word = state ++ cex.slice(mid, cex.length)
      val actual = mq(word)
      if actual == expected then low = mid
      else high = mid
    low

  /** Returns a new learner updated with the given counterexample `cex`.
    *
    * The breakpoint symbol of `cex` is sifted from the source leaf to its target leaf. If the guard of that transition
    * already contains the symbol, the target leaf is split with the rest of `cex` as the new suffix. Otherwise the
    * guard learners of the sifted transition and of the transition the hypothesis actually took are both updated with
    * the symbol.
    */
  def update[P](cex: Seq[A])(using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): SfaLearner[L, A] =
    println(s"update($cex)")
    val hypothesis = conject[P]()
    val breakpoint = analyzeCex(hypothesis, cex)
    val srcLeaf = hypothesis.transitions(hypothesis.initialState, cex.slice(0, breakpoint)).get
    val char = cex(breakpoint)
    val tgtWord = srcLeaf ++ Seq(char)
    val tgtLeaf = tree.sift(tgtWord)
    val guard = L.conject(guardLearnerMap((srcLeaf, tgtLeaf)))(using membership(srcLeaf, tgtLeaf))
    if P.contains(guard, char) then
      val newSuffix = cex.slice(breakpoint + 1, cex.length)
      val afterSplit = split[P](tgtLeaf, tgtWord, newSuffix)
      // The split reset several guard learners, so the guards of every leaf are re-established.
      afterSplit.tree.leafSet.foldLeft(afterSplit): (learner, leaf) =>
        learner.makeGuardsCompleteAndDeterministic[P](leaf)
    else
      val newTgtLeafGuard =
        L.update(guardLearnerMap((srcLeaf, tgtLeaf)), char)(using membership(srcLeaf, tgtLeaf))
      val tgtState = hypothesis.transition(srcLeaf, char).get
      val newTgtStateGuard =
        L.update(guardLearnerMap((srcLeaf, tgtState)), char)(using membership(srcLeaf, tgtState))
      val newLearner = copy(guardLearnerMap =
        guardLearnerMap ++ Map((srcLeaf, tgtLeaf) -> newTgtLeafGuard, (srcLeaf, tgtState) -> newTgtStateGuard)
      )
      newLearner.makeGuardsCompleteAndDeterministic[P](srcLeaf)

  /** Returns the hypothesis SFA conjectured by the learner.
    *
    * The states are the leaves of `tree`, the initial state is the empty prefix, the accepting states are the leaves
    * mapped to `true` in `acceptMap`, and each transition guard is the conjecture of its guard learner.
    */
  def conject[P]()(using mq: Membership[Seq[A]], L: Learner[L, A, P]): Sfa[Prefix[A], P] =
    val initialState: Prefix[A] = Seq.empty
    val acceptStateSet = acceptMap.keySet.filter(acceptMap(_))
    val leafSet = tree.leafSet
    val transitionFunction = leafSet.iterator.map: srcLeaf =>
      val guardMap = leafSet.iterator.map: tgtLeaf =>
        val guard = L.conject(guardLearnerMap((srcLeaf, tgtLeaf)))(using membership(srcLeaf, tgtLeaf))
        guard -> tgtLeaf
      // NOTE(review): edges are keyed by guard, so target leaves with equal guards collapse into one entry.
      srcLeaf -> guardMap.toMap
    Sfa(initialState, acceptStateSet, transitionFunction.toMap)
object SfaLearner:

  /** Returns the initial SFA learner: a tree with the single leaf for the empty prefix, whose guards have been made
    * complete and deterministic.
    */
  def empty[L, A, P](using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): SfaLearner[L, A] =
    val root: Prefix[A] = Seq.empty
    val initial = SfaLearner(
      CTree.Leaf(root),
      Map(root -> mq(root)),
      // With a single leaf, every symbol loops back to it.
      Map((root, root) -> L.create(using (_ => true)))
    )
    initial.makeGuardsCompleteAndDeterministic(root)

  /** `Learner` instance for SFAs, built from a learner `L` of guards and the Boolean algebra `P` of predicates. */
  given learner[L, A, P](using
      L: Learner[L, A, P],
      P: BoolAlg[A, P]
  ): Learner[SfaLearner[L, A], Seq[A], Sfa[Seq[A], P]] with
    def create(using Membership[Seq[A]]): SfaLearner[L, A] = SfaLearner.empty

    def update(learner: SfaLearner[L, A], cex: Seq[A])(using Membership[Seq[A]]): SfaLearner[L, A] =
      learner.update(cex)

    def conject(learner: SfaLearner[L, A])(using Membership[Seq[A]]): Sfa[Seq[A], P] =
      learner.conject()

  /** Creates an equivalence query from the given membership query and finite alphabet.
    *
    * The query prints the hypothesis, then runs it on `numWords` random words whose lengths are drawn between
    * `minWordLength` and `maxWordLength` (inclusive). After each symbol, the acceptance of the hypothesis is compared
    * with `mq` on the word read so far, and the first disagreeing word is returned as the counterexample. The random
    * generator is re-created from `randomSeed` on every call of the query.
    */
  def equivalence[A, P](
      mq: Membership[Seq[A]],
      finiteAlphabet: Set[A],
      minWordLength: Int = 10,
      maxWordLength: Int = 100,
      numWords: Int = 100,
      randomSeed: Long = 0L
  )(using BoolAlg[A, P]): (Sfa[Prefix[A], P]) => Option[Seq[A]] =
    val symbols = finiteAlphabet.toIndexedSeq
    (sfa: Sfa[Prefix[A], P]) =>
      println(sfa)
      val rng = util.Random(randomSeed)
      util.boundary:
        for _ <- 0 until numWords do
          val length = rng.between(minWordLength, maxWordLength + 1)
          var word = Seq.empty[A]
          var state = sfa.initialState
          for _ <- 0 until length do
            val symbol = symbols(rng.nextInt(symbols.size))
            word = word :+ symbol
            state = sfa.transition(state, symbol).get
            val accepted = sfa.acceptStateSet.contains(state)
            if accepted != mq(word) then util.boundary.break(Some(word))
        None
/** Membership query of the example target language: the words over `Int` containing equally many `0`s and `2`s, and
  * fewer than three of each.
  */
val mq = new Membership[Seq[Int]]:
  def apply(word: Seq[Int]): Boolean =
    val zeros = word.count(_ == 0)
    val twos = word.count(_ == 2)
    zeros == twos && zeros < 3

/** Random-testing equivalence query over the alphabet `{0, 1, 2}`. */
val eq = SfaLearner.equivalence[Int, IntervalSet](mq, Set(0, 1, 2))

/** The SFA learned for the example language. */
val sfa = Learner.learn[SfaLearner[IntervalSetLearner, Int], Seq[Int], Sfa[Prefix[Int], IntervalSet]](mq, eq)
println(sfa)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment