Created
July 31, 2024 13:31
-
-
Save makenowjust/28bbb8384354262ec9609a05e4c88b4d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// kv.scala - an implementation of the Kearns & Vazirani algorithm | |
import scala.annotation.tailrec | |
import scala.collection.mutable | |
import scala.compiletime.ops.double | |
/** States of the automata in this file; the learner names each state by its access string. */
type Q = java.lang.String
/** A deterministic finite-state automaton (DFA).
  *
  * @param initial    the start state
  * @param acceptSet  the accepting states
  * @param transition the transition map, keyed by (state, input symbol)
  */
final case class DFA(
    initial: Q,
    acceptSet: Set[Q],
    transition: Map[(Q, Char), Q]
):
  /** Feeds `w` through the automaton starting at `initial` and reports whether the state
    * reached is accepting. A transition missing from `transition` makes the lookup throw
    * `NoSuchElementException`.
    */
  def run(w: String): Boolean =
    val reached = w.foldLeft(initial) { (state, symbol) => transition(state -> symbol) }
    acceptSet.contains(reached)
/** A system under learning (SUL), answering the learner's membership and equivalence queries.
  *
  * Membership answers are memoized; both kinds of query are tallied for `stats`.
  *
  * @param alphabet        the input symbols
  * @param membershipFunc  the membership oracle: is the given string in the target language?
  * @param equivalenceFunc the equivalence oracle: a counterexample for a hypothesis, or `None`
  */
final case class SUL(
    alphabet: Set[Char],
    membershipFunc: String => Boolean,
    equivalenceFunc: DFA => Option[String]
):
  // Answers already obtained from `membershipFunc`, keyed by the query string.
  private val membershipCache = mutable.Map.empty[String, Boolean]
  // Number of equivalence queries issued so far (these are never cached).
  private var equivalenceCount = 0

  /** Asks a `membership` query with the input string `w`; repeats are served from the cache. */
  def membership(w: String): Boolean =
    membershipCache.getOrElseUpdate(w, membershipFunc(w))

  /** Asks an `equivalence` query with the hypothesis DFA `h`, counting the call. */
  def equivalence(h: DFA): Option[String] =
    equivalenceCount = equivalenceCount + 1
    equivalenceFunc(h)

  /** Returns a two-line statistics summary: `#MEMBER` is the number of distinct membership
    * queries, `#EQUIV` the number of equivalence queries.
    */
  def stats: String =
    s"""|#MEMBER = ${membershipCache.size}
        |#EQUIV = ${equivalenceCount}""".stripMargin
/** Shortcut for the `alphabet` of the SUL given in the current context. */
def alphabet(using sul: SUL): Set[Char] =
  sul.alphabet
/** Shortcut: asks the contextual SUL a `membership` query for `w`. */
def membership(w: String)(using sul: SUL): Boolean =
  sul.membership(w)
/** Shortcut: asks the contextual SUL an `equivalence` query for the hypothesis `h`. */
def equivalence(h: DFA)(using sul: SUL): Option[String] =
  sul.equivalence(h)
/** A node of a discrimination tree.
  *
  * An inner `Node` carries a discriminator suffix `label`: a string `w` is routed to `one`
  * when `w ++ label` is a member and to `zero` otherwise. A `Leaf` carries one access string.
  */
enum DiscriminationNode:
  case Node(label: String, one: DiscriminationNode, zero: DiscriminationNode)
  case Leaf(state: Q)

  /** Collects the access strings stored in the leaves below (and including) this node. */
  def accessSet: Set[Q] = this match
    case Leaf(state)        => Set(state)
    case Node(_, one, zero) => one.accessSet | zero.accessSet

  /** Routes `w` down from this node by membership queries and returns the access string of
    * the leaf it lands in.
    */
  def sift(w: String)(using SUL): Q =
    @tailrec
    def descend(node: DiscriminationNode): Q = node match
      case Leaf(state) => state
      case Node(label, one, zero) =>
        descend(if membership(w ++ label) then one else zero)
    descend(this)

  /** Builds the hypothesis DFA over the access strings of this subtree.
    *
    * The initial state is `""` (the learner always keeps `""` as an access string). An access
    * string is accepting when it is itself a member, and the `c`-successor of `q` is
    * `sift(q ++ c)`.
    */
  def toDFA(using SUL): DFA =
    // One pass per state, in `accessSet` order: its membership first, then its successors.
    val rows = accessSet.iterator.map { q =>
      val isAccepting = membership(q)
      val edges = alphabet.iterator.map(c => (q, c) -> sift(q ++ c.toString)).toList
      (q, isAccepting, edges)
    }.toList
    DFA(
      "",
      rows.collect { case (q, true, _) => q }.toSet,
      rows.flatMap(_._3).toMap
    )
/** A discrimination tree together with the root-to-leaf path of every access string.
  *
  * @param root  the root node (in `learn` it is created with the empty discriminator `""`,
  *              and `update` only ever replaces leaves)
  * @param paths for each access string, the branch taken at every inner node on the way from
  *              the root to its leaf (`true` = `one`, `false` = `zero`)
  */
final case class DiscriminationTree(
    root: DiscriminationNode.Node,
    paths: Map[Q, Seq[Boolean]]
):
  import DiscriminationNode.*

  /** A shortcut for `root.sift(w)`. */
  def sift(w: String)(using SUL): Q =
    root.sift(w)

  /** A shortcut for `root.toDFA`. */
  def toDFA(using SUL): DFA =
    root.toDFA

  /** Returns a distinguishing suffix for two different access strings `q1` and `q2`. This is
    * the label of the deepest inner node on both of their paths, where the paths diverge.
    *
    * @throws IllegalArgumentException if `q1 == q2`
    */
  def distinguish(q1: Q, q2: Q): String =
    require(q1 != q2, s"cannot distinguish an access string from itself: '$q1'")
    @tailrec
    def loop(
        node: DiscriminationNode,
        p1: Seq[Boolean],
        p2: Seq[Boolean]
    ): String =
      node match
        case Node(label, one, zero) =>
          if p1.head == p2.head then
            val next = if p1.head then one else zero
            loop(next, p1.tail, p2.tail)
          else label
        // Distinct access strings sit in distinct leaves, so their paths diverge before any leaf.
        case Leaf(_) => sys.error("unreachable")
    loop(root, paths(q1), paths(q2))

  /** Returns a new tree refined with the counterexample `ce` to the hypothesis `h`. The leaf
    * of one existing access string is split to add the new access string taken from a prefix
    * of `ce`.
    *
    * `h` is presumably `this.toDFA`; the splitting argument below relies on that.
    *
    * @throws IllegalArgumentException if the tree and `h` agree on every prefix of `ce`, i.e.
    *                                  `ce` is not a counterexample for this tree's hypothesis
    */
  def update(ce: String, h: DFA)(using SUL): DiscriminationTree =
    // Walk `ce` in lockstep through the tree (by sifting prefixes) and through `h`. Stop at the
    // first position `j` where the two disagree on the state reached by `ce[0..j]`.
    @tailrec
    def split(j: Int, tq0: Q, hq0: Q): (Q, Q, Q, String, Char) =
      require(
        j < ce.length,
        s"not a counterexample for the current hypothesis: '$ce'"
      )
      val tq = sift(ce.substring(0, j + 1))
      val hq = h.transition((hq0, ce.charAt(j)))
      if hq != tq then (tq0, tq, hq, ce.substring(0, j), ce.charAt(j))
      else split(j + 1, tq, hq)

    val (tq0, tq, hq, nq, c) = split(0, sift(""), h.initial)
    // `c ++ d` separates the old access string `tq0` from the new one `nq`: after `c` they reach
    // `hq` and `tq` respectively, and `d` distinguishes those two.
    val d = distinguish(tq, hq)
    val cd = c.toString ++ d
    val accepted = membership(tq0 ++ cd)
    val (oneQ, zeroQ) = if accepted then (tq0, nq) else (nq, tq0)
    val p = paths(tq0)

    // Rebuild the spine along `path`, replacing the leaf of `tq0` with a new inner node.
    def replace(node: DiscriminationNode, path: Seq[Boolean]): Node =
      node match
        case Node(label, one, zero) =>
          if path.head then Node(label, replace(one, path.tail), zero)
          else Node(label, one, replace(zero, path.tail))
        case Leaf(_) =>
          Node(cd, Leaf(oneQ), Leaf(zeroQ))

    val newRoot = replace(root, p)
    val newPaths = paths + (tq0 -> (p :+ accepted)) + (nq -> (p :+ !accepted))
    DiscriminationTree(newRoot, newPaths)
/** Runs the Kearns & Vazirani algorithm against `teacher`.
  *
  * Returns the first hypothesis for which the teacher's equivalence query yields no
  * counterexample.
  *
  * The first hypothesis has the single state `""`; it accepts everything or nothing, depending on
  * whether `""` is a member. Its counterexample seeds a two-leaf tree split by the empty
  * discriminator. Every later counterexample refines the tree via `DiscriminationTree.update`.
  */
def learn(teacher: SUL): DFA =
  import DiscriminationNode.*
  given SUL = teacher

  val emptyAccepted = membership("")
  val trivial = DFA(
    "",
    if emptyAccepted then Set("") else Set.empty,
    alphabet.iterator.map(c => ("", c) -> "").toMap
  )

  // Alternate hypothesis construction and equivalence queries until the teacher agrees.
  @tailrec
  def refine(tree: DiscriminationTree): DFA =
    val hypothesis = tree.toDFA
    equivalence(hypothesis) match
      case None     => hypothesis
      case Some(ce) => refine(tree.update(ce, hypothesis))

  equivalence(trivial) match
    case None => trivial
    case Some(ce0) =>
      // `ce0` is classified opposite to `""`, so the two land on different sides of the root.
      val (one, zero) =
        if emptyAccepted then (Leaf(""), Leaf(ce0)) else (Leaf(ce0), Leaf(""))
      refine(
        DiscriminationTree(
          Node("", one, zero),
          Map("" -> Seq(emptyAccepted), ce0 -> Seq(!emptyAccepted))
        )
      )
/** Demo: learns the language of binary strings that contain a `0` and whose number of `1`s is
  * 3 modulo 4. It then prints the learned DFA, a blank line, and the query statistics.
  *
  * The equivalence oracle is approximate: it only compares the hypothesis with the target on
  * the binary representations of 0 to 256.
  */
@main
def main(): Unit =
  def target(w: String): Boolean =
    w.contains("0") && w.count(_ == '1') % 4 == 3

  val samples = (0 to 256).map(_.toBinaryString)
  val sul = SUL(
    Set('0', '1'),
    target,
    hypothesis => samples.find(w => target(w) != hypothesis.run(w))
  )

  val learned = learn(sul)
  println(learned)
  println()
  println(sul.stats)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment