Skip to content

Instantly share code, notes, and snippets.

@makenowjust
Created July 31, 2024 13:31
Show Gist options
  • Save makenowjust/28bbb8384354262ec9609a05e4c88b4d to your computer and use it in GitHub Desktop.
Save makenowjust/28bbb8384354262ec9609a05e4c88b4d to your computer and use it in GitHub Desktop.
// kv.scala - an implementation of the Kearns & Vazirani algorithm
import scala.annotation.tailrec
import scala.collection.mutable
import scala.compiletime.ops.double
/** A type for states.
  *
  * States are plain strings: the learner names each state by its access string, and the
  * hypothesis built by `toDFA` uses the empty string as the initial state.
  */
type Q = String
/** A deterministic finite-state automaton (DFA).
  *
  * @param initial    the state every run starts in
  * @param acceptSet  the accepting states
  * @param transition the transition table, keyed by `(state, input character)`
  */
final case class DFA(
    initial: Q,
    acceptSet: Set[Q],
    transition: Map[(Q, Char), Q]
):
  /** Feeds `w` to the automaton one character at a time and reports whether the state
    * reached at the end is accepting.
    *
    * The table is looked up with `Map.apply`, so a `(state, character)` pair that is not
    * in `transition` makes the run throw.
    */
  def run(w: String): Boolean =
    def step(state: Q, symbol: Char): Q = transition(state -> symbol)
    acceptSet.contains(w.foldLeft(initial)(step))
/** A system under learning (SUL).
  *
  * @param alphabet        the input alphabet
  * @param membershipFunc  decides whether a word belongs to the target language
  * @param equivalenceFunc checks a hypothesis DFA; `None` accepts it, `Some(w)` returns a
  *                        counterexample word
  */
final case class SUL(
    alphabet: Set[Char],
    membershipFunc: String => Boolean,
    equivalenceFunc: DFA => Option[String]
):
  // Answers already obtained from `membershipFunc`, so each distinct word is asked once.
  private val membershipCache = mutable.Map.empty[String, Boolean]

  // Number of equivalence queries issued so far.
  private var equivalenceCount = 0

  /** Ask a `membership` query with the input string `w`, reusing a cached answer if any. */
  def membership(w: String): Boolean =
    membershipCache.get(w) match
      case Some(answer) => answer
      case None =>
        val answer = membershipFunc(w)
        membershipCache.update(w, answer)
        answer

  /** Ask an `equivalence` query with the hypothesis DFA `h`; every call is counted. */
  def equivalence(h: DFA): Option[String] =
    equivalenceCount = equivalenceCount + 1
    equivalenceFunc(h)

  /** Returns the statistics of the SUL: the number of distinct membership queries and the
    * number of equivalence queries.
    */
  def stats: String =
    s"""|#MEMBER = ${membershipCache.size}
        |#EQUIV = ${equivalenceCount}""".stripMargin
/** A shortcut for the `alphabet` of the SUL in the given context. */
def alphabet(using SUL): Set[Char] =
  val sul = summon[SUL]
  sul.alphabet
/** A shortcut for a `membership` query on the SUL in the given context. */
def membership(w: String)(using SUL): Boolean =
  val sul = summon[SUL]
  sul.membership(w)
/** A shortcut for an `equivalence` query on the SUL in the given context. */
def equivalence(h: DFA)(using SUL): Option[String] =
  val sul = summon[SUL]
  sul.equivalence(h)
/** A node of a discrimination tree.
  *
  * An inner `Node` carries a distinguishing suffix `label`; words `w` for which
  * `w ++ label` is a member go to `one`, all others to `zero`. A `Leaf` carries the access
  * string that names a state.
  */
enum DiscriminationNode:
  case Node(label: String, one: DiscriminationNode, zero: DiscriminationNode)
  case Leaf(state: Q)

  /** Returns the set of access strings (state strings in leaves). */
  def accessSet: Set[Q] = this match
    case Leaf(state)        => Set(state)
    case Node(_, one, zero) => one.accessSet ++ zero.accessSet

  /** Returns the access string corresponding to the given prefix `w`.
    *
    * Starting at this node, one membership query per inner node picks the branch to
    * follow until a leaf is reached.
    */
  def sift(w: String)(using SUL): Q =
    @tailrec
    def descend(node: DiscriminationNode): Q = node match
      case Leaf(state) => state
      case Node(label, one, zero) =>
        descend(if membership(w ++ label) then one else zero)
    descend(this)

  /** Returns the hypothesis DFA of this.
    *
    * The states are the access strings in the leaves, the initial state is the empty
    * string, a state accepts when its access string is a member, and the successor of
    * `q` on `c` is found by sifting `q ++ c`.
    */
  def toDFA(using SUL): DFA =
    val (accepting, delta) =
      accessSet.foldLeft((Set.empty[Q], Map.empty[(Q, Char), Q])) {
        case ((accepted, edges), q) =>
          // Same query order as a plain loop: first the state itself, then its successors.
          val accepted1 = if membership(q) then accepted + q else accepted
          val edges1 = alphabet.foldLeft(edges) { (m, c) =>
            m.updated((q, c), sift(q ++ c.toString))
          }
          (accepted1, edges1)
      }
    DFA("", accepting, delta)
/** A discrimination tree.
  *
  * @param root  the root node of the tree
  * @param paths for each access string, the branch choices from the root down to its leaf
  *              (`true` selects `one`, `false` selects `zero`)
  */
final case class DiscriminationTree(
    root: DiscriminationNode.Node,
    paths: Map[Q, Seq[Boolean]]
):
  import DiscriminationNode.*

  /** A shortcut for `root.sift(w)`. */
  def sift(w: String)(using SUL): Q = root.sift(w)

  /** A shortcut for `root.toDFA`. */
  def toDFA(using SUL): DFA = root.toDFA

  /** Returns the distinguishing string of two access strings `q1` and `q2`: the label of
    * the node at which their paths from the root first take different branches.
    */
  def distinguish(q1: Q, q2: Q): String =
    @tailrec
    def walk(node: DiscriminationNode, left: Seq[Boolean], right: Seq[Boolean]): String =
      node match
        case Leaf(_)                                       => sys.error("unreachable")
        case Node(label, _, _) if left.head != right.head  => label
        case Node(_, one, zero) =>
          walk(if left.head then one else zero, left.tail, right.tail)
    walk(root, paths(q1), paths(q2))

  /** Returns a new tree updated with the given counterexample string `ce`.
    *
    * `ce` is scanned from the left, comparing the state the tree assigns to each prefix
    * with the state the hypothesis `h` moves to. At the first disagreement the leaf of the
    * last agreed state is replaced by an inner node that separates that state from a new
    * state named by the prefix read so far.
    */
  def update(ce: String, h: DFA)(using SUL): DiscriminationTree =
    // Result: (state before the step, tree state after, hypothesis state after,
    //          prefix read before the step, character of the step).
    @tailrec
    def firstDivergence(i: Int, treeBefore: Q, hypBefore: Q): (Q, Q, Q, String, Char) =
      val treeAfter = sift(ce.substring(0, i + 1))
      val hypAfter = h.transition(hypBefore -> ce.charAt(i))
      if hypAfter == treeAfter then firstDivergence(i + 1, treeAfter, hypAfter)
      else (treeBefore, treeAfter, hypAfter, ce.substring(0, i), ce.charAt(i))

    val (oldState, treeNext, hypNext, newState, symbol) =
      firstDivergence(0, sift(""), h.initial)

    // The step character followed by a suffix separating the two successor states
    // separates the old state from the new one.
    val suffix = symbol.toString ++ distinguish(treeNext, hypNext)
    val oldGoesToOne = membership(oldState ++ suffix)
    val oldPath = paths(oldState)

    val splitNode: Node =
      if oldGoesToOne then Node(suffix, Leaf(oldState), Leaf(newState))
      else Node(suffix, Leaf(newState), Leaf(oldState))

    // Rebuilds the spine along `route`, swapping the leaf at its end for `splitNode`.
    def graft(node: DiscriminationNode, route: Seq[Boolean]): Node =
      node match
        case Leaf(_) => splitNode
        case Node(label, one, zero) if route.head =>
          Node(label, graft(one, route.tail), zero)
        case Node(label, one, zero) =>
          Node(label, one, graft(zero, route.tail))

    val grownRoot = graft(root, oldPath)
    val grownPaths = paths
      .updated(oldState, oldPath :+ oldGoesToOne)
      .updated(newState, oldPath :+ !oldGoesToOne)
    DiscriminationTree(grownRoot, grownPaths)
/** Run the Kearns & Vazirani algorithm against `teacher` and return the first hypothesis
  * DFA for which the equivalence query yields no counterexample.
  */
def learn(teacher: SUL): DFA =
  import DiscriminationNode.*
  given SUL = teacher

  // Keep refining the tree with counterexamples until a hypothesis is accepted.
  @tailrec
  def refine(tree: DiscriminationTree): DFA =
    val hypothesis = tree.toDFA
    equivalence(hypothesis) match
      case Some(ce) => refine(tree.update(ce, hypothesis))
      case None     => hypothesis

  // The first hypothesis has a single state that loops on every character.
  val emptyAccepted = membership("")
  val trivial = DFA(
    "",
    if emptyAccepted then Set("") else Set.empty,
    alphabet.map(c => ("", c) -> "").toMap
  )

  equivalence(trivial) match
    case None => trivial
    case Some(ce0) =>
      // The first counterexample becomes the second state; the empty label at the root
      // separates it from the initial state by membership alone.
      val root: Node =
        if emptyAccepted then Node("", Leaf(""), Leaf(ce0))
        else Node("", Leaf(ce0), Leaf(""))
      val paths = Map("" -> Seq(emptyAccepted), ce0 -> Seq(!emptyAccepted))
      refine(DiscriminationTree(root, paths))
/** Demo entry point.
  *
  * Learns the language of words over `{0, 1}` that contain a `0` and whose number of `1`s
  * is congruent to 3 modulo 4, then prints the learned DFA and the query statistics.
  */
@main
def main(): Unit =
  val symbols = Set('0', '1')
  val membership =
    (w: String) => w.contains("0") && w.count(_ == '1') % 4 == 3

  // Longest word compared by the bounded equivalence check.
  val maxLength = 9

  // Every word over `symbols` of length 0 to `maxLength`, shortest first. Enumerating the
  // words directly (rather than via `Int.toBinaryString`) also covers the empty word and
  // words with leading zeros, so a hypothesis that is wrong only there is still rejected.
  def candidates: Iterator[String] =
    val ordered = symbols.toSeq.sorted
    Iterator
      .iterate(Seq(""))(level => level.flatMap(w => ordered.map(c => s"$w$c")))
      .take(maxLength + 1)
      .flatten

  val sul = SUL(
    symbols,
    membership,
    // The first word on which target and hypothesis disagree, if any.
    h => candidates.find(w => membership(w) != h.run(w))
  )
  val h = learn(sul)
  println(h)
  println()
  println(sul.stats)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment