Skip to content

Instantly share code, notes, and snippets.

@remysucre
Last active October 7, 2024 18:37
Show Gist options
  • Save remysucre/1788cf0153d7db240e751fb698f74d99 to your computer and use it in GitHub Desktop.
Save remysucre/1788cf0153d7db240e751fb698f74d99 to your computer and use it in GitHub Desktop.
merge-only rules in egg
use egg::*;
/// A rewrite [`Condition`] that holds only when a pattern is already present
/// in the e-graph.
///
/// Guarding a rewrite with this condition turns it into a "merge-only" rule:
/// it fires only when its target term already exists.
pub struct ConditionExists<L> {
    /// Pattern that is instantiated with a match's substitution and looked up.
    p: Pattern<L>,
}
impl<L: Language> ConditionExists<L> {
    /// Wraps an already-built pattern `p` in a [`ConditionExists`].
    pub fn new(p: Pattern<L>) -> Self {
        Self { p }
    }
}
impl<L: FromOp> ConditionExists<L> {
    /// Creates a [`ConditionExists`] by parsing `a` as a pattern string.
    ///
    /// # Panics
    ///
    /// Panics if `a` is not a valid pattern. The panic message names the
    /// offending string and the parse error. Because of `#[track_caller]`,
    /// the reported location is the call site (typically the rule
    /// definition), not this function.
    #[track_caller]
    pub fn parse(a: &str) -> Self {
        match a.parse() {
            Ok(p) => Self { p },
            Err(err) => panic!("invalid pattern {:?} for ConditionExists: {:?}", a, err),
        }
    }
}
impl<L, N> Condition<L, N> for ConditionExists<L>
where
    L: Language,
    N: Analysis<L>,
{
    /// Returns `true` iff every e-node of `self.p`, instantiated with `subst`,
    /// is found by `lookup_pat` in `egraph`. The matched e-class is ignored.
    fn check(&self, egraph: &mut EGraph<L, N>, _eclass: Id, subst: &Subst) -> bool {
        let nodes = self.p.ast.as_ref();
        // Scratch buffer: slot k receives the e-class id resolved for nodes[k].
        let mut scratch: Vec<Id> = vec![0.into(); nodes.len()];
        lookup_pat(&mut scratch, nodes, egraph, subst)
    }

    /// Variables this condition expects the rule's searcher to bind.
    fn vars(&self) -> Vec<Var> {
        self.p.vars()
    }
}
/// Checks whether the pattern `pat`, instantiated with `subst`, is already
/// represented in `egraph`, without adding anything to it.
///
/// `pat` is a flattened pattern AST whose e-node children are indices into
/// `pat` itself. Children are assumed to precede their parents (egg's
/// `RecExpr` layout). As each position `k` is resolved, its e-class id is
/// stored in `ids[k]` so that later nodes can rewrite their children to those
/// ids. `ids` must be at least `pat.len()` long.
///
/// Pattern variables are read by indexing `subst`, so every variable in `pat`
/// must be bound there.
///
/// Returns `true` iff every e-node of the instantiated pattern was found via
/// `EGraph::lookup`. Returns `false` at the first one that is missing.
fn lookup_pat<L: Language, A: Analysis<L>>(
    ids: &mut [Id],
    pat: &[ENodeOrVar<L>],
    egraph: &EGraph<L, A>,
    subst: &Subst,
) -> bool {
    for (idx, pat_node) in pat.iter().enumerate() {
        let id = match pat_node {
            ENodeOrVar::Var(var) => subst[*var],
            ENodeOrVar::ENode(node) => {
                // Point the node's children at the e-classes resolved earlier.
                let node = node.clone().map_children(|child| ids[usize::from(child)]);
                match egraph.lookup(node) {
                    Some(found) => found,
                    // A sub-term is absent, so the whole pattern cannot be present.
                    None => return false,
                }
            }
        };
        ids[idx] = id;
    }
    true
}
fn main() {
let rules: Vec<Rewrite<SymbolLang, ()>> = vec![
// this merge-only rule only fires if b is in the e-graph
rewrite!("cond"; "a" => "b" if ConditionExists::parse("b")),
];
// the merge-only rule does not fire if rhs does not exist
{
let runner = Runner::default()
.with_expr(&"a".parse().unwrap())
.run(&rules);
assert_eq!(
runner.egraph.equivs(&"a".parse().unwrap(), &"b".parse().unwrap()).len(),
0
);
}
// with rhs in the e-graph, the merge-only rule fires
{
let runner = Runner::default()
.with_expr(&"a".parse().unwrap())
.with_expr(&"b".parse().unwrap())
.run(&rules);
assert_eq!(
runner.egraph.equivs(&"a".parse().unwrap(),
&"b".parse().unwrap()).len(),
1
);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment