Created
February 20, 2025 05:51
-
-
Save makenowjust/53ed1b8e066952df9d4578d18d20097e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This is an implementation of the MAT* algorithm in Scala 3. | |
// | |
// The MAT* algorithm is a learning algorithm for symbolic finite automata, proposed by | |
// George Argyros and Loris D'Antoni, "The Learnability of Symbolic Automata" | |
// https://doi.org/10.1007/978-3-319-96145-3_23. | |
import scala.annotation.tailrec | |
import scala.util.hashing.MurmurHash3 | |
/** An effective Boolean algebra over the domain `D`.
  *
  * Predicates are values of type `P`. An implementation supplies the two constant predicates, a
  * membership test, a witness generator, and the three Boolean connectives.
  *
  * @tparam D the domain the predicates range over
  * @tparam P the representation of predicates
  */
trait BoolAlg[D, P]:
  /** The predicate satisfied by every element of the domain. */
  def `true`: P

  /** The predicate satisfied by no element of the domain. */
  def `false`: P

  /** Tests whether the element `d` lies in the denotation of `p`. */
  def contains(p: P, d: D): Boolean

  /** Some element of the denotation of `p`, or `None` when no such element exists. */
  def witness(p: P): Option[D]

  /** The conjunction `p ∧ q`. */
  def and(p: P, q: P): P

  /** The disjunction `p ∨ q`. */
  def or(p: P, q: P): P

  /** The negation `¬p`. */
  def not(p: P): P
/** `Interval` is a closed interval `[left, right]` of integers.
  *
  * Instances built through the companion's `apply` satisfy `left <= right`; the raw constructor
  * performs no such check.
  */
final class Interval(val left: Int, val right: Int):
  /** Checks if the interval contains the given integer `n`. */
  def contains(n: Int): Boolean = left <= n && n <= right

  /** Checks if the interval shares at least one integer with `that`. */
  def overlaps(that: Interval): Boolean =
    math.max(left, that.left) <= math.min(right, that.right)

  /** Checks if the interval is immediately adjacent to `that`: one of them ends exactly one
    * integer before the other begins (e.g. `[1, 3]` and `[4, 5]`).
    */
  def contiguous(that: Interval): Boolean =
    // Compare in `Long` so that `right + 1` cannot wrap around when `right == Int.MaxValue`.
    // With `Int` arithmetic, two copies of `[Int.MinValue, Int.MaxValue]` were reported as
    // contiguous even though they overlap.
    right.toLong + 1L == that.left.toLong || that.right.toLong + 1L == left.toLong

  /** Computes the union of this interval and `that` when it is a single interval (they overlap or
    * are adjacent); returns `None` otherwise.
    */
  def unionOf(that: Interval): Option[Interval] =
    Option.when(overlaps(that) || contiguous(that)):
      Interval(math.min(left, that.left), math.max(right, that.right))

  override def equals(that: Any): Boolean = that match
    case that: Interval => left == that.left && right == that.right
    case _              => false

  override def hashCode(): Int =
    var hash = "Interval".##
    hash = MurmurHash3.mix(hash, left.##)
    hash = MurmurHash3.mix(hash, right.##)
    MurmurHash3.finalizeHash(hash, 2)

  override def toString(): String = s"Interval($left, $right)"

object Interval:
  /** Returns a new interval `[left, right]`.
    *
    * @throws IllegalArgumentException if `left > right`
    */
  def apply(left: Int, right: Int): Interval =
    require(left <= right, s"Invalid interval: $left > $right")
    new Interval(left, right)
/** A finite union of integer intervals, stored as a sequence of `Interval`s.
  *
  * The operations assume `intervals` is sorted and that neighbouring intervals neither overlap nor
  * touch, which is the form produced by `IntervalSet.from`. The raw constructor does not check this.
  */
final class IntervalSet(val intervals: IndexedSeq[Interval]):
  /** `true` iff the set holds no interval at all. */
  def isEmpty: Boolean = intervals.isEmpty

  /** Tests whether the integer `n` belongs to one of the intervals. */
  def contains(n: Int): Boolean =
    // Binary search on the right endpoints: the first interval whose `right` is >= `n` is the only
    // candidate that may contain `n`.
    val byRight = Ordering.by[Interval, Int](_.right)
    val candidate = intervals.search(Interval(n, n))(byRight).insertionPoint
    intervals.lift(candidate).exists(_.contains(n))

  /** The set of integers lying in this set or in `that`. */
  infix def union(that: IntervalSet): IntervalSet =
    IntervalSet.from(intervals.concat(that.intervals))

  /** The set of integers lying in both this set and `that`, obtained by De Morgan's law. */
  infix def intersect(that: IntervalSet): IntervalSet =
    val outsideEither = complement.union(that.complement)
    outsideEither.complement

  /** The set of integers not in this set. */
  def complement: IntervalSet =
    if intervals.isEmpty then IntervalSet.universal
    else
      val first = intervals.head
      val last = intervals.last
      // The complement is made of the integers below the first interval, the gaps between
      // consecutive intervals, and the integers above the last interval.
      val below = Option.when(first.left != Int.MinValue)(Interval(Int.MinValue, first.left - 1))
      val gaps = intervals.zip(intervals.tail).map((lo, hi) => Interval(lo.right + 1, hi.left - 1))
      val above = Option.when(last.right != Int.MaxValue)(Interval(last.right + 1, Int.MaxValue))
      new IntervalSet(below.toIndexedSeq ++ gaps ++ above)

  override def equals(that: Any): Boolean =
    that match
      case other: IntervalSet => intervals == other.intervals
      case _                  => false

  override def hashCode(): Int =
    MurmurHash3.mixLast("IntervalSet".##, intervals.##)

  override def toString(): String = intervals.mkString("IntervalSet(", ", ", ")")
object IntervalSet:
  /** The interval set containing no integer. */
  val empty: IntervalSet = new IntervalSet(IndexedSeq.empty)

  /** The interval set containing every `Int`. */
  val universal: IntervalSet = new IntervalSet(IndexedSeq(Interval(Int.MinValue, Int.MaxValue)))

  /** Builds a normalized interval set from the given intervals. */
  def apply(intervals: Interval*): IntervalSet = from(intervals)

  /** Builds a normalized interval set: the intervals are sorted by `(left, right)` and every run of
    * overlapping or adjacent intervals is merged into a single interval.
    */
  def from(intervals: IterableOnce[Interval]): IntervalSet =
    val sorted = intervals.iterator.toSeq.sortBy(i => (i.left, i.right))
    sorted match
      case first +: rest =>
        // Accumulate finished intervals in `done` while growing the `current` one.
        val (done, current) = rest.foldLeft((Vector.empty[Interval], first)): (acc, next) =>
          val (closed, open) = acc
          open.unionOf(next) match
            case Some(grown) => (closed, grown)
            case None        => (closed :+ open, next)
        new IntervalSet(done :+ current)
      case _ => empty

  /** `BoolAlg` instance over `Int` whose predicates are interval sets. */
  given boolAlg: BoolAlg[Int, IntervalSet] with
    def `true`: IntervalSet = universal
    def `false`: IntervalSet = empty
    def and(p: IntervalSet, q: IntervalSet): IntervalSet = p.intersect(q)
    def or(p: IntervalSet, q: IntervalSet): IntervalSet = p.union(q)
    def not(p: IntervalSet): IntervalSet = p.complement
    def contains(p: IntervalSet, d: Int): Boolean = p.contains(d)
    // The witness is the left end of the first stored interval, if there is one.
    def witness(p: IntervalSet): Option[Int] = p.intervals.headOption.map(_.left)
/** A membership oracle for a target language over inputs of type `A`. */
trait Membership[A]:
  /** Answers `true` exactly when `a` belongs to the target language. */
  def apply(a: A): Boolean
/** Type class for a Boolean-algebra learner that can be driven by membership queries.
  *
  * @tparam L the learner instance (its state)
  * @tparam A the input data
  * @tparam H the hypothesis model the learner conjectures
  */
trait Learner[L, A, H]:
  /** Creates a new learner instance. */
  def create(using Membership[A]): L

  /** Returns `learner` refined with the counterexample `cex`. */
  def update(learner: L, cex: A)(using Membership[A]): L

  /** Returns the hypothesis model currently conjectured by `learner`. */
  def conject(learner: L)(using Membership[A]): H
object Learner:
  /** Runs the learning loop. It conjectures a hypothesis, asks `eq` for a counterexample, and
    * refines the learner with it. This repeats until `eq` accepts the hypothesis.
    *
    * @param mq membership oracle for the target
    * @param eq equivalence oracle: `None` when the hypothesis is accepted, otherwise a counterexample
    * @return the hypothesis conjectured by the final learner
    */
  def learn[L, A, H](mq: Membership[A], eq: (H) => Option[A])(using L: Learner[L, A, H]): H =
    given Membership[A] = mq

    // Refines `current` with counterexamples until the equivalence oracle is satisfied.
    @tailrec
    def refine(current: L): L =
      eq(L.conject(current)) match
        case Some(counterexample) => refine(L.update(current, counterexample))
        case None                 => current

    L.conject(refine(L.create))
/** `IntervalSetLearner` is a learner for the `IntervalSet` Boolean algebra.
  *
  * It records every classified example. Its conjecture is either the finite set of positive
  * examples or the complement of the finite set of negative examples, whichever is described by
  * fewer points.
  *
  * @param posExampleSet inputs the membership query classified as members
  * @param negExampleSet inputs the membership query classified as non-members
  */
final case class IntervalSetLearner(posExampleSet: Set[Int], negExampleSet: Set[Int]):
  /** Returns a new learner that also records `cex`, classified by the membership query `mq`. */
  def update(cex: Int)(using mq: Membership[Int]): IntervalSetLearner =
    if mq(cex) then copy(posExampleSet = posExampleSet + cex)
    else copy(negExampleSet = negExampleSet + cex)

  /** Returns the hypothesis conjectured by the learner.
    *
    * - With no positive example it is the empty set.
    * - With no negative example it is the universal set.
    * - Otherwise it is the smaller of the two point-set descriptions.
    */
  def conject(): IntervalSet =
    if posExampleSet.isEmpty then IntervalSet.empty
    else if negExampleSet.isEmpty then IntervalSet.universal
    else if posExampleSet.size < negExampleSet.size then pointsOf(posExampleSet)
    else pointsOf(negExampleSet).complement

  /** The interval set consisting of the singleton intervals `[n, n]` for every `n` in `points`. */
  private def pointsOf(points: Set[Int]): IntervalSet =
    // A single `from` sorts and merges once (O(n log n)). Folding `union` over singletons re-sorted
    // the whole accumulated set on every step (O(n² log n)) while producing the same result.
    IntervalSet.from(points.iterator.map(n => Interval(n, n)))
object IntervalSetLearner:
  /** The learner that has recorded no example yet. */
  val empty: IntervalSetLearner = IntervalSetLearner(posExampleSet = Set.empty, negExampleSet = Set.empty)

  /** Exposes `IntervalSetLearner` through the generic `Learner` type class. */
  given learner: Learner[IntervalSetLearner, Int, IntervalSet] with
    def create(using Membership[Int]): IntervalSetLearner = empty

    def update(learner: IntervalSetLearner, cex: Int)(using mq: Membership[Int]): IntervalSetLearner =
      learner.update(cex)(using mq)

    def conject(learner: IntervalSetLearner)(using Membership[Int]): IntervalSet =
      learner.conject()
/** `Sfa` represents a symbolic finite automaton.
  *
  * In this implementation, SFAs are assumed to be deterministic and finite. Each state maps guard
  * predicates of type `P` to successor states.
  *
  * @param initialState       the start state
  * @param acceptStateSet     the accepting states
  * @param transitionFunction for each state, its outgoing edges keyed by guard
  */
final case class Sfa[S, P](
    initialState: S,
    acceptStateSet: Set[S],
    transitionFunction: Map[S, Map[P, S]]
):
  /** Computes the successor of `state` on the input symbol `char`.
    *
    * Returns the target of the first outgoing edge whose guard contains `char`, or `None` when
    * there is no such edge. This includes the case where `state` has no entry in
    * `transitionFunction` at all.
    */
  def transition[A](state: S, char: A)(using P: BoolAlg[A, P]): Option[S] =
    // `get` rather than `apply`: a state without outgoing edges must yield `None`, as the return
    // type promises, instead of throwing `NoSuchElementException`.
    transitionFunction
      .get(state)
      .flatMap(_.collectFirst { case (guard, target) if P.contains(guard, char) => target })

  /** Computes the state reached from `state` by reading `word`, or `None` if the run gets stuck. */
  def transitions[A](state: S, word: Seq[A])(using P: BoolAlg[A, P]): Option[S] =
    word.foldLeft(Option(state))((current, char) => current.flatMap(transition(_, char)))
/** A prefix of a word: the symbols read so far, as a plain `Seq`. */
type Prefix[T] = Seq[T]

/** A suffix of a word: the symbols still to be read, as a plain `Seq`. */
type Suffix[T] = Seq[T]
/** A classification tree.
  *
  * Leaves carry access strings (prefixes). Inner nodes carry a distinguishing suffix and route a
  * word to `trueBranch` when `word ++ suffix` is accepted by the membership query, and to
  * `falseBranch` otherwise.
  */
enum CTree[A]:
  /** A leaf labelled with an access string. */
  case Leaf(prefix: Prefix[A])

  /** An inner node labelled with a distinguishing suffix. */
  case Node(suffix: Suffix[A], trueBranch: CTree[A], falseBranch: CTree[A])

  /** The access strings stored in all leaves of the tree. */
  def leafSet: Set[Prefix[A]] = this match
    case Leaf(prefix)                 => Set(prefix)
    case Node(_, whenTrue, whenFalse) => whenTrue.leafSet.union(whenFalse.leafSet)

  /** Walks `word` down the tree using membership queries and returns the leaf it reaches. */
  @tailrec
  final def sift(word: Prefix[A])(using mq: Membership[Seq[A]]): Seq[A] = this match
    case Leaf(prefix) => prefix
    case Node(suffix, trueBranch, falseBranch) =>
      val next = if mq(word ++ suffix) then trueBranch else falseBranch
      next.sift(word)

  /** Replaces the leaf `oldLeaf` with an inner node labelled `newSuffix`.
    *
    * The new node's children are `oldLeaf` and `newLeaf`. `oldLeaf` goes to the branch that
    * `mq(oldLeaf ++ newSuffix)` selects, and `newLeaf` goes to the other branch.
    */
  final def split(oldLeaf: Prefix[A], newLeaf: Prefix[A], newSuffix: Suffix[A])(using
      mq: Membership[Seq[A]]
  ): CTree[A] =
    this match
      case Leaf(leaf) =>
        assert(leaf == oldLeaf, s"Invalid prefix: $leaf != $oldLeaf")
        val (acceptedSide, rejectedSide) =
          if mq(oldLeaf ++ newSuffix) then (oldLeaf, newLeaf) else (newLeaf, oldLeaf)
        Node(newSuffix, Leaf(acceptedSide), Leaf(rejectedSide))
      case Node(suffix, trueBranch, falseBranch) =>
        // Descend along the path `oldLeaf` itself sifts through.
        if mq(oldLeaf ++ suffix) then Node(suffix, trueBranch.split(oldLeaf, newLeaf, newSuffix), falseBranch)
        else Node(suffix, trueBranch, falseBranch.split(oldLeaf, newLeaf, newSuffix))
/** `SfaLearner` is a learner for symbolic finite automata (the MAT* algorithm).
  *
  * The states of the conjectured SFA are the leaves (access strings) of `tree`. For every ordered
  * pair of leaves `(src, tgt)`, `guardLearnerMap` holds a learner of type `L` for the guard of the
  * transition `src -> tgt`. That guard is the set of symbols `a` such that `src ++ Seq(a)` sifts to
  * `tgt`.
  *
  * @param tree            the classification tree whose leaves are the hypothesis states
  * @param acceptMap       for each leaf, the membership-query answer on that access string
  * @param guardLearnerMap guard learners keyed by `(source leaf, target leaf)`
  * @tparam L the guard learner type
  * @tparam A the input symbol type
  */
final case class SfaLearner[L, A](
    tree: CTree[A],
    acceptMap: Map[Prefix[A], Boolean],
    guardLearnerMap: Map[(Prefix[A], Prefix[A]), L]
):
  /** Returns the membership query for the guard learner of the transition `leaf1 -> leaf2`.
    *
    * A symbol `a` is a member iff `leaf1 ++ Seq(a)` sifts to `leaf2` in this learner's `tree`.
    */
  private def membership(leaf1: Prefix[A], leaf2: Prefix[A])(using Membership[Seq[A]]): Membership[A] =
    new Membership[A]:
      def apply(a: A): Boolean = tree.sift(leaf1 ++ Seq(a)) == leaf2

  /** Splits `oldLeaf` in the classification tree, introducing `newLeaf` distinguished from it by
    * `newSuffix`.
    *
    * It then resets every guard learner whose target language changes: all transitions into or
    * out of `newLeaf`, and all transitions into `oldLeaf`.
    */
  private def split[P](oldLeaf: Prefix[A], newLeaf: Prefix[A], newSuffix: Suffix[A])(using
      mq: Membership[Seq[A]],
      L: Learner[L, A, P]
  ): SfaLearner[L, A] =
    println(s"split($oldLeaf, $newLeaf, $newSuffix)")
    val newTree = tree.split(oldLeaf, newLeaf, newSuffix)
    val newAcceptMap = acceptMap + (newLeaf -> mq(newLeaf))
    val newLeafPairs = newTree.leafSet.iterator.flatMap(leaf => Iterator((newLeaf, leaf), (leaf, newLeaf)))
    val oldLeafPairs = tree.leafSet.iterator.map(leaf => (leaf, oldLeaf))
    // The reset learners must answer membership queries against the *split* tree. `this.membership`
    // sifts with the old tree, which never reaches `newLeaf` and still routes to `oldLeaf` words
    // that now belong to `newLeaf`. A learner that queries during `create` would then be trained
    // on wrong answers.
    val splitLearner = SfaLearner(newTree, newAcceptMap, guardLearnerMap)
    val resetLearners = (newLeafPairs ++ oldLeafPairs).map((leaf1, leaf2) =>
      (leaf1, leaf2) -> L.create(using splitLearner.membership(leaf1, leaf2))
    )
    splitLearner.copy(guardLearnerMap = guardLearnerMap ++ resetLearners)

  /** Makes the guards out of `srcLeaf` complete.
    *
    * While some symbol is covered by none of the guards out of `srcLeaf`, that symbol is taught to
    * the learner of the transition it actually takes.
    *
    * @return whether any guard learner was updated, and the resulting learner
    */
  private def makeGuardsComplete[P](
      srcLeaf: Prefix[A]
  )(using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): (Boolean, SfaLearner[L, A]) =
    println(s"makeGuardsComplete($srcLeaf)")
    val leafSet = tree.leafSet
    var newGuardLearnerMap = guardLearnerMap
    var updated = false
    var continue = true
    while continue do
      val guards = leafSet.map: leaf =>
        L.conject(newGuardLearnerMap((srcLeaf, leaf)))(using membership(srcLeaf, leaf))
      // A witness of ¬(g1 ∨ … ∨ gn) is a symbol no outgoing guard covers.
      val completePred = P.not(guards.foldLeft(P.`false`)(P.or(_, _)))
      P.witness(completePred) match
        case None => continue = false
        case Some(cex) =>
          val tgtLeaf = tree.sift(srcLeaf ++ Seq(cex))
          val learner = newGuardLearnerMap((srcLeaf, tgtLeaf))
          val newLearner = L.update(learner, cex)(using membership(srcLeaf, tgtLeaf))
          newGuardLearnerMap = newGuardLearnerMap.updated((srcLeaf, tgtLeaf), newLearner)
          updated = true
    (updated, copy(guardLearnerMap = newGuardLearnerMap))

  /** Makes the guards out of `srcLeaf` deterministic.
    *
    * While two guards out of `srcLeaf` share a symbol, every guard that disagrees with its
    * membership query on that symbol is corrected. This stops after a full pass over all pairs
    * makes no update.
    *
    * @return whether any guard learner was updated, and the resulting learner
    */
  private def makeGuardsDeterministic[P](
      srcLeaf: Prefix[A]
  )(using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): (Boolean, SfaLearner[L, A]) =
    println(s"makeGuardsDeterministic($srcLeaf)")
    val leafSet = tree.leafSet
    var newGuardLearnerMap = guardLearnerMap
    var updated = false
    var continue = true
    while continue do
      var updatedLocal = false
      for leaf1 <- leafSet; leaf2 <- leafSet; if leaf1 != leaf2 do
        val guard1 = L.conject(newGuardLearnerMap((srcLeaf, leaf1)))(using membership(srcLeaf, leaf1))
        val guard2 = L.conject(newGuardLearnerMap((srcLeaf, leaf2)))(using membership(srcLeaf, leaf2))
        P.witness(P.and(guard1, guard2)) match
          case None => ()
          case Some(cex) =>
            // `srcLeaf ++ Seq(cex)` sifts to at most one of the two leaves, so at least one guard
            // is wrong on `cex`. Correct each guard that disagrees with its membership query.
            for (leaf, guard) <- Seq((leaf1, guard1), (leaf2, guard2)) do
              if membership(srcLeaf, leaf)(cex) != P.contains(guard, cex) then
                val learner = newGuardLearnerMap((srcLeaf, leaf))
                val newLearner = L.update(learner, cex)(using membership(srcLeaf, leaf))
                newGuardLearnerMap = newGuardLearnerMap.updated((srcLeaf, leaf), newLearner)
                updatedLocal = true
                updated = true
      if !updatedLocal then continue = false
    (updated, copy(guardLearnerMap = newGuardLearnerMap))

  /** Alternates `makeGuardsComplete` and `makeGuardsDeterministic` on `srcLeaf` until a round in
    * which neither updates any guard learner.
    */
  private def makeGuardsCompleteAndDeterministic[P](
      srcLeaf: Prefix[A]
  )(using Membership[Seq[A]], Learner[L, A, P], BoolAlg[A, P]): SfaLearner[L, A] =
    println(s"makeGuardsCompleteAndDeterministic($srcLeaf)")
    var learner: SfaLearner[L, A] = this
    var isComplete = false
    var isDeterministic = false
    while !isComplete || !isDeterministic do
      val result1 = learner.makeGuardsComplete[P](srcLeaf)
      learner = result1._2
      val result2 = learner.makeGuardsDeterministic[P](srcLeaf)
      learner = result2._2
      isComplete = !result1._1
      isDeterministic = !result2._1
    learner

  /** Returns the index of the breakpoint in the counterexample `cex` for `hypothesis`.
    *
    * For `i` in `0..cex.length`, let `w(i)` be the access string of the hypothesis state reached
    * on `cex.take(i)`, followed by `cex.drop(i)`. Then `w(0) == cex`. A binary search returns an
    * index `low` with `mq(w(low)) == mq(cex)` and `mq(w(low + 1)) != mq(cex)`. This assumes `cex`
    * is a genuine counterexample, so that `w(cex.length)` disagrees with `cex`.
    */
  def analyzeCex[P](hypothesis: Sfa[Prefix[A], P], cex: Seq[A])(using mq: Membership[Seq[A]], P: BoolAlg[A, P]): Int =
    val expected = mq(cex)
    var low = 0
    var high = cex.length
    while high - low > 1 do
      val mid = (high - low) / 2 + low
      val state = hypothesis.transitions(hypothesis.initialState, cex.slice(0, mid)).get
      val word = state ++ cex.slice(mid, cex.length)
      if mq(word) == expected then low = mid
      else high = mid
    low

  /** Returns a new learner updated with the given counterexample `cex`.
    *
    * There are two cases at the breakpoint.
    * - The guard from the source state covers the breakpoint symbol: the hypothesis transition is
    *   right, but the remaining suffix distinguishes a new state, so the tree is split.
    * - Otherwise the guard is wrong: the symbol is taught to the learners of both the correct
    *   transition and the one the hypothesis took.
    */
  def update[P](cex: Seq[A])(using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): SfaLearner[L, A] =
    println(s"update($cex)")
    val hypothesis = conject[P]()
    val breakpoint = analyzeCex(hypothesis, cex)
    val srcLeaf = hypothesis.transitions(hypothesis.initialState, cex.slice(0, breakpoint)).get
    val tgtWord = srcLeaf ++ Seq(cex(breakpoint))
    val tgtLeaf = tree.sift(tgtWord)
    val guard = L.conject(guardLearnerMap((srcLeaf, tgtLeaf)))(using membership(srcLeaf, tgtLeaf))
    if P.contains(guard, cex(breakpoint)) then
      val newSuffix = cex.slice(breakpoint + 1, cex.length)
      var newLearner = split[P](tgtLeaf, tgtWord, newSuffix)
      for leaf <- newLearner.tree.leafSet do newLearner = newLearner.makeGuardsCompleteAndDeterministic[P](leaf)
      newLearner
    else
      val newTgtLeafGuard =
        L.update(guardLearnerMap((srcLeaf, tgtLeaf)), cex(breakpoint))(using membership(srcLeaf, tgtLeaf))
      val tgtState = hypothesis.transition(srcLeaf, cex(breakpoint)).get
      val newTgtStateGuard =
        L.update(guardLearnerMap((srcLeaf, tgtState)), cex(breakpoint))(using membership(srcLeaf, tgtState))
      val newLearner = copy(guardLearnerMap =
        guardLearnerMap ++ Map((srcLeaf, tgtLeaf) -> newTgtLeafGuard, (srcLeaf, tgtState) -> newTgtStateGuard)
      )
      newLearner.makeGuardsCompleteAndDeterministic[P](srcLeaf)

  /** Returns the hypothesis SFA conjectured by the learner.
    *
    * - States are the leaves of `tree`, and the initial state is the empty access string.
    * - Accepting states are the leaves mapped to `true` in `acceptMap`.
    * - Each transition `src -> tgt` is guarded by the current conjecture of its guard learner.
    */
  def conject[P]()(using mq: Membership[Seq[A]], L: Learner[L, A, P]): Sfa[Prefix[A], P] =
    val initialState: Prefix[A] = Seq.empty
    val acceptStateSet = acceptMap.keySet.filter(acceptMap(_))
    val leafSet = tree.leafSet
    val transitionFunction = leafSet.iterator.map: srcLeaf =>
      val guardMap = leafSet.iterator.map: tgtLeaf =>
        val guard = L.conject(guardLearnerMap((srcLeaf, tgtLeaf)))(using membership(srcLeaf, tgtLeaf))
        guard -> tgtLeaf
      srcLeaf -> guardMap.toMap
    Sfa(initialState, acceptStateSet, transitionFunction.toMap)
object SfaLearner:
  /** Returns an SFA learner whose tree is the single leaf `ε`, with the self-loop guard on `ε`
    * already made complete and deterministic.
    */
  def empty[L, A, P](using mq: Membership[Seq[A]], L: Learner[L, A, P], P: BoolAlg[A, P]): SfaLearner[L, A] =
    val root: Prefix[A] = Seq.empty
    // With a single leaf every word sifts to it, so the self-loop guard's membership is always true.
    val alwaysTrue: Membership[A] = _ => true
    val initial = SfaLearner(
      CTree.Leaf(root),
      Map(root -> mq(root)),
      Map((root, root) -> L.create(using alwaysTrue))
    )
    initial.makeGuardsCompleteAndDeterministic[P](root)

  /** Exposes `SfaLearner` through the generic `Learner` type class, given a guard learner `L`. */
  given learner[L, A, P](using L: Learner[L, A, P], P: BoolAlg[A, P]): Learner[SfaLearner[L, A], Seq[A], Sfa[Seq[A], P]] with
    def create(using Membership[Seq[A]]): SfaLearner[L, A] = SfaLearner.empty[L, A, P]

    def update(learner: SfaLearner[L, A], cex: Seq[A])(using Membership[Seq[A]]): SfaLearner[L, A] =
      learner.update[P](cex)

    def conject(learner: SfaLearner[L, A])(using Membership[Seq[A]]): Sfa[Seq[A], P] =
      learner.conject[P]()

  /** Creates a randomized equivalence query from a membership query and a finite alphabet.
    *
    * Each call does the following.
    * - Prints the hypothesis.
    * - Reseeds a generator with `randomSeed`.
    * - Performs up to `numWords` random walks over `finiteAlphabet`, each of a length drawn from
    *   `[minWordLength, maxWordLength]`.
    * - Returns the first prefix on which the hypothesis and `mq` disagree, or `None` if none is found.
    *
    * It assumes the hypothesis has a transition for every alphabet symbol from every reached state.
    */
  def equivalence[A, P](
      mq: Membership[Seq[A]],
      finiteAlphabet: Set[A],
      minWordLength: Int = 10,
      maxWordLength: Int = 100,
      numWords: Int = 100,
      randomSeed: Long = 0L
  )(using BoolAlg[A, P]): (Sfa[Prefix[A], P]) => Option[Seq[A]] =
    val alphabet = finiteAlphabet.toIndexedSeq
    (sfa: Sfa[Prefix[A], P]) =>
      println(sfa)
      val rand = util.Random(randomSeed)

      // Reads `remaining` more random symbols, starting from `state` with `word` already read.
      // After each symbol it compares the hypothesis with `mq` and stops at the first disagreement.
      @tailrec
      def walk(remaining: Int, word: Seq[A], state: Prefix[A]): Option[Seq[A]] =
        if remaining <= 0 then None
        else
          val symbol = alphabet(rand.nextInt(alphabet.size))
          val nextWord = word :+ symbol
          val nextState = sfa.transition(state, symbol).get
          if sfa.acceptStateSet.contains(nextState) != mq(nextWord) then Some(nextWord)
          else walk(remaining - 1, nextWord, nextState)

      Iterator
        .continually(walk(rand.between(minWordLength, maxWordLength + 1), Seq.empty[A], sfa.initialState))
        .take(numWords)
        .collectFirst { case Some(counterexample) => counterexample }
/** Target language of the demo: words in which the number of `0`s equals the number of `2`s and
  * both counts are below 3. The symbol `1` is unconstrained.
  */
val mq = new Membership[Seq[Int]]:
  def apply(word: Seq[Int]): Boolean =
    val zeros = word.count(_ == 0)
    val twos = word.count(_ == 2)
    zeros == twos && zeros < 3 && twos < 3

val eq = SfaLearner.equivalence[Int, IntervalSet](mq = mq, finiteAlphabet = Set(0, 1, 2))

val sfa = Learner.learn[SfaLearner[IntervalSetLearner, Int], Seq[Int], Sfa[Prefix[Int], IntervalSet]](mq, eq)
println(sfa)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment