Last active
August 14, 2019 19:27
-
-
Save wernsey/b9a6c7be1ee4a15718e28527c7d58767 to your computer and use it in GitHub Desktop.
Python Datalog
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! python
# Version 1.0
# Naive, with no negation.
# Also, no parser
class DatalogException(Exception):
    """Error raised when a fact or a rule breaks a Datalog restriction."""
class Expr:
    """A Datalog atom: a predicate name applied to a list of terms.

    A string term whose first character is upper case is a variable; every
    other term is treated as a constant.
    """

    def __init__(self, predicate, *args):
        self.predicate = predicate
        self.terms = list(args)

    def __str__(self):
        # str() each term so atoms holding non-string constants still print.
        return "%s(%s)" % (self.predicate, ",".join(str(t) for t in self.terms))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        # List equality covers both the arity check and the term-by-term check.
        return self.predicate == other.predicate and self.terms == other.terms

    def __hash__(self):
        # Hash the terms as an ordered tuple: summing the individual hashes
        # would make p(a,b) and p(b,a) collide every time.
        return hash((self.predicate, tuple(self.terms)))

    def arity(self):
        """Return the number of terms."""
        return len(self.terms)

    @staticmethod
    def isVariable(term):
        """Return True if ``term`` is a string starting with an upper-case letter.

        The empty string and non-string terms are constants, not variables.
        """
        return isinstance(term, str) and term[:1].isupper()

    def isGround(self):
        """Return True if none of the terms is a variable."""
        return not any(Expr.isVariable(t) for t in self.terms)

    def substitute(self, bindings):
        """Return a new Expr with every bound variable replaced by its value.

        Variables that are not in ``bindings`` are copied unchanged.
        """
        that = Expr(self.predicate)
        for t in self.terms:
            if Expr.isVariable(t) and t in bindings:
                that.terms.append(bindings[t])
            else:
                that.terms.append(t)
        return that

    def unify(self, that, bindings):
        """Try to unify this atom with ``that``, recording variables in ``bindings``.

        Returns True on success and False on failure.  ``bindings`` is updated
        in place and may already have been extended when False is returned.
        """
        if self.predicate != that.predicate or self.arity() != that.arity():
            return False
        for term1, term2 in zip(self.terms, that.terms):
            if Expr.isVariable(term1):
                # The same variable on both sides needs no binding.
                if term1 != term2:
                    if term1 not in bindings:
                        bindings[term1] = term2
                    elif bindings[term1] != term2:
                        return False
            elif Expr.isVariable(term2):
                if term2 not in bindings:
                    bindings[term2] = term1
                elif bindings[term2] != term1:
                    return False
            elif term1 != term2:
                return False
        return True
class Rule:
    """A Datalog rule ``head :- body``, where the body is a list of goals."""

    def __init__(self, head, *args):
        self.head = head
        self.body = list(args)

    def __str__(self):
        return "%s :- %s" % (self.head, ", ".join(str(e) for e in self.body))

    def validate(self):
        """Check that every variable in the head also occurs in the body.

        Raises:
            DatalogException: if the head uses a variable the body never
                mentions.
        """
        headvars = {t for t in self.head.terms if Expr.isVariable(t)}
        bodyvars = {t for e in self.body for t in e.terms if Expr.isVariable(t)}
        delta = headvars - bodyvars
        if delta:
            # Sort so the message does not depend on set iteration order.
            raise DatalogException(
                "Variables in head not in body: %s" % (",".join(sorted(delta)))
            )
class Binding():
    """A variable-to-value mapping that falls back to an optional parent.

    Lookups try this binding's own entries first and then the parent chain;
    assignments always go into this binding's own entries.
    """

    def __init__(self, parent=None):
        self.dict = dict()
        self.parent = parent

    def __contains__(self, item):
        if item in self.dict:
            return True
        # Always answer with a real bool (never None) when called directly.
        return bool(self.parent) and item in self.parent

    def __getitem__(self, item):
        """Return the value bound to ``item``, or None when it is unbound."""
        if item in self.dict:
            return self.dict[item]
        if self.parent:
            return self.parent[item]
        return None

    def __setitem__(self, key, value):
        self.dict[key] = value

    def keys(self):
        """Return the bound names as a keys view, parent names first.

        Each name appears once, even if this binding shadows a parent entry.
        """
        if not self.parent:
            return self.dict.keys()
        # dict.fromkeys de-duplicates while keeping a stable order, unlike
        # building a set whose iteration order varies between runs.
        merged = dict.fromkeys(self.parent.keys())
        merged.update(dict.fromkeys(self.dict))
        return merged.keys()

    def iter(self):
        """Yield ``(name, value)`` pairs for every name in ``keys()``."""
        for k in self.keys():
            yield (k, self[k])

    def print(self):
        """Print each bound name and its value on a line of its own."""
        for k, v in self.iter():
            print(k, v)
class Datalog:
    """A naive bottom-up Datalog evaluator.

    ``edb`` is the set of ground facts added with ``fact()`` and ``idb`` is
    the list of rules added with ``rule()``.
    """

    def __init__(self):
        self.idb = []
        self.edb = set()

    def fact(self, predicate, *args):
        """Add a ground fact and return self so calls can be chained.

        Raises:
            DatalogException: if any argument is a variable.
        """
        f = Expr(predicate, *args)
        if not f.isGround():
            raise DatalogException("Facts must be ground: %s" % (f,))
        self.edb.add(f)
        return self

    def rule(self, head, *args):
        """Add the rule ``head :- args`` and return self so calls can be chained.

        Raises:
            DatalogException: if the rule fails ``Rule.validate()``.
        """
        r = Rule(head, *args)
        r.validate()
        self.idb.append(r)
        return self

    def query(self, goals):
        """Return a list of Bindings, one per way of satisfying all ``goals``.

        An empty goal list yields an empty result.
        """
        if len(goals) == 0:
            return []
        dataset = self.expand(self.edb, self.idb)
        return self.matchGoals(goals, dataset)

    def expand(self, facts, rules):
        """Apply ``rules`` to ``facts`` until a pass derives nothing new.

        Returns the resulting set of facts; the set passed in is not modified.
        """
        while True:
            newFacts = set()
            for rule in rules:
                newFacts |= self.matchRule(facts, rule)
            # Fixed point: this pass derived only facts we already had.
            if newFacts <= facts:
                return facts
            facts = facts | newFacts

    def matchRule(self, facts, rule):
        """Return the set of head facts derivable from ``facts`` by one rule."""
        return {
            rule.head.substitute(answer)
            for answer in self.matchGoals(rule.body, facts)
        }

    def matchGoals(self, goals, facts, bindings=None):
        """Return every binding that satisfies all ``goals`` against ``facts``.

        ``bindings`` holds what earlier goals have already bound.  An empty
        goal list is trivially satisfied and yields one answer, so a rule
        with no body derives its head instead of raising IndexError.
        """
        if not goals:
            return [Binding() if bindings is None else bindings]
        goal, rest = goals[0], goals[1:]
        answers = []
        for fact in facts:
            if fact.predicate != goal.predicate:
                continue
            # Give each candidate fact its own child binding so a failed
            # unification cannot leak entries into the other candidates.
            newBindings = Binding(bindings)
            if fact.unify(goal, newBindings):
                if rest:
                    answers.extend(self.matchGoals(rest, facts, newBindings))
                else:
                    answers.append(newBindings)
        return answers
if __name__ == "__main__":
    datalog = Datalog()

    # Family tree used by the demo, as (parent, child) pairs.
    for parent, child in [
        ('alice', 'bob'),
        ('arnold', 'bob'),
        ('bob', 'carol'),
        ('betty', 'carol'),
        ('carol', 'dean'),
        ('carol', 'dennis'),
        ('charles', 'dennis'),
        ('ben', 'catrina'),
        ('ben', 'charles'),
        ('ben', 'cherry'),
        ('bella', 'catrina'),
        ('bella', 'cherry'),
        ('cherry', 'donna'),
        ('dennis', 'edna'),
        ('donna', 'edna'),
        ('dean', 'elvis'),
        ('catrina', 'dana'),
        ('dana', 'eric'),
        ('alana', 'bill'),
    ]:
        datalog.fact('parent', parent, child)

    datalog.rule(Expr('ancestor', 'A', 'B'), Expr('parent', 'A', 'B'))
    datalog.rule(Expr('ancestor', 'A', 'B'), Expr('ancestor', 'A', 'C'), Expr('parent', 'C', 'B'))
    datalog.rule(Expr('descendant', 'A', 'B'), Expr('ancestor', 'B', 'A'))
    # SGC - same generation cousin
    datalog.rule(Expr('sgc', 'A', 'B'), Expr('parent', 'Z', 'A'), Expr('parent', 'Z', 'B'))
    datalog.rule(Expr('sgc', 'A', 'B'), Expr('parent', 'P', 'A'), Expr('sgc', 'P', 'Q'), Expr('parent', 'Q', 'B'))

    # Each entry is one query: a list of goals that must all hold.
    queries = [
        [Expr('ancestor', 'carol', 'X')],
        [Expr('ancestor', 'B', 'dennis'), Expr('ancestor', 'B', 'dean')],
        [Expr('ancestor', 'B', 'dennis'), Expr('ancestor', 'B', 'donna')],
        [Expr('descendant', 'X', 'ben')],
        [Expr('sgc', 'dennis', 'Cousin')],
    ]
    for goals in queries:
        answers = datalog.query(goals)
        print("----------------")
        for a in answers:
            print("answer: ", end='')
            a.print()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! python
from typing import AbstractSet, Dict, Iterable, List, Optional, Sequence, Set, Tuple
# Version 1.0
# Naive, with no negation.
# Also, no parser
#
# Type checking with mypy:
# $ python3 -m pip install -U mypy
# $ mypy datalog.py
class DatalogException(Exception):
    """Error raised when a fact or a rule breaks a Datalog restriction."""
class Expr:
    """A Datalog atom: a predicate name applied to a list of terms.

    Whether a term is a variable is decided by the module-level
    ``isVariable`` function.
    """

    def __init__(self, predicate, *args):
        self.predicate = predicate
        self.terms = list(args)

    def __str__(self) -> str:
        # str() each term so atoms holding non-string constants still print.
        return "%s(%s)" % (self.predicate, ",".join(str(t) for t in self.terms))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return NotImplemented
        # List equality covers both the arity check and the term-by-term check.
        return self.predicate == other.predicate and self.terms == other.terms

    def __hash__(self) -> int:
        # Hash the terms as an ordered tuple: summing the individual hashes
        # would make p(a,b) and p(b,a) collide every time.
        return hash((self.predicate, tuple(self.terms)))

    def arity(self) -> int:
        """Return the number of terms."""
        return len(self.terms)

    def isGround(self) -> bool:
        """Return True if none of the terms is a variable."""
        return not any(isVariable(t) for t in self.terms)

    def substitute(self, binding) -> "Expr":
        """Return a new Expr with every bound variable replaced by its value.

        Variables that are not in ``binding`` are copied unchanged.
        """
        that = Expr(self.predicate)
        for t in self.terms:
            if isVariable(t) and t in binding:
                that.terms.append(binding[t])
            else:
                that.terms.append(t)
        return that

    def unify(self, that, binding) -> bool:
        """Try to unify this atom with ``that``, recording variables in ``binding``.

        Returns True on success and False on failure.  ``binding`` is updated
        in place and may already have been extended when False is returned.
        """
        if self.predicate != that.predicate or self.arity() != that.arity():
            return False
        for term1, term2 in zip(self.terms, that.terms):
            if isVariable(term1):
                # The same variable on both sides needs no binding.
                if term1 != term2:
                    if term1 not in binding:
                        binding[term1] = term2
                    elif binding[term1] != term2:
                        return False
            elif isVariable(term2):
                if term2 not in binding:
                    binding[term2] = term1
                elif binding[term2] != term1:
                    return False
            elif term1 != term2:
                return False
        return True
class Rule:
    """A Datalog rule ``head :- body``, where the body is a list of goals."""

    def __init__(self, head, *args):
        self.head = head
        self.body = list(args)

    def __str__(self) -> str:
        return "%s :- %s" % (self.head, ", ".join(str(e) for e in self.body))

    def validate(self) -> None:
        """Check that every variable in the head also occurs in the body.

        Raises:
            DatalogException: if the head uses a variable the body never
                mentions.
        """
        headvars = {t for t in self.head.terms if isVariable(t)}
        bodyvars = {t for e in self.body for t in e.terms if isVariable(t)}
        delta = headvars - bodyvars
        if delta:
            # Sort so the message does not depend on set iteration order.
            raise DatalogException(
                "Variables in head not in body: %s" % (",".join(sorted(delta)))
            )
class Binding():
    """A variable-to-value mapping that falls back to an optional parent.

    Lookups try this binding's own entries first and then the parent chain;
    assignments always go into this binding's own entries.
    """

    def __init__(self, parent: Optional["Binding"] = None) -> None:
        self.dict: Dict[str, str] = dict()
        self.parent = parent

    def __contains__(self, item: str) -> bool:
        if item in self.dict:
            return True
        # Always answer with a real bool (never None) when called directly.
        return bool(self.parent) and item in self.parent

    def __getitem__(self, item: str):
        """Return the value bound to ``item``, or None when it is unbound."""
        if item in self.dict:
            return self.dict[item]
        if self.parent:
            return self.parent[item]
        return None

    def __setitem__(self, key: str, value: str) -> None:
        self.dict[key] = value

    def keys(self) -> Iterable[str]:
        """Return the bound names as a keys view, parent names first.

        Each name appears once, even if this binding shadows a parent entry.
        """
        if not self.parent:
            return self.dict.keys()
        # dict.fromkeys de-duplicates while keeping a stable order, unlike
        # building a set whose iteration order varies between runs.
        merged = dict.fromkeys(self.parent.keys())
        merged.update(dict.fromkeys(self.dict))
        return merged.keys()

    def iter(self) -> Iterable[Tuple[str, str]]:
        """Yield ``(name, value)`` pairs for every name in ``keys()``."""
        for k in self.keys():
            yield (k, self[k])

    def __str__(self) -> str:
        pairs = ["%s=%s" % (k, v) for k, v in self.iter()]
        return "{%s}" % "; ".join(pairs)
class Datalog:
    """A naive bottom-up Datalog evaluator.

    ``edb`` is the set of ground facts added with ``fact()`` and ``idb`` is
    the list of rules added with ``rule()``.
    """

    def __init__(self) -> None:
        self.idb: List["Rule"] = []
        self.edb: Set["Expr"] = set()

    def fact(self, predicate: str, *args: str) -> "Datalog":
        """Add a ground fact and return self so calls can be chained.

        Raises:
            DatalogException: if any argument is a variable.
        """
        f = Expr(predicate, *args)
        if not f.isGround():
            raise DatalogException("Facts must be ground: %s" % (f,))
        self.edb.add(f)
        return self

    def rule(self, head: "Expr", *args: "Expr") -> "Datalog":
        """Add the rule ``head :- args`` and return self so calls can be chained.

        Raises:
            DatalogException: if the rule fails ``Rule.validate()``.
        """
        r = Rule(head, *args)
        r.validate()
        self.idb.append(r)
        return self

    def query(self, goals: Sequence["Expr"]) -> List["Binding"]:
        """Return a list of Bindings, one per way of satisfying all ``goals``.

        An empty goal list yields an empty result.
        """
        if len(goals) == 0:
            return []
        dataset = self.expand(self.edb, self.idb)
        return self.matchGoals(goals, dataset)

    def expand(self, facts: Set["Expr"], rules: Iterable["Rule"]) -> AbstractSet["Expr"]:
        """Apply ``rules`` to ``facts`` until a pass derives nothing new.

        Prints a trace of every iteration and of the facts derived so far.
        Returns ``facts`` together with all derived facts; the set passed in
        is not modified.
        """
        iteration = 1
        R: Set["Expr"] = set()
        while True:
            print("iteration %d:" % iteration)
            iteration += 1
            S = set(R)  # snapshot of what was derived before this pass
            for rule in rules:
                # Gauss-Seidel method: each rule already sees what earlier
                # rules derived in this same pass.  The Jacobi method would
                # match against ``facts | S`` instead.
                results = self.matchRule(facts | R, rule)
                R = R | results
            for r in R:
                print("result: %s" % r)
            # R only ever grows, so equal sizes mean nothing new was derived.
            if len(R) == len(S):
                return facts | S

    def matchRule(self, facts: AbstractSet["Expr"], rule: "Rule") -> Set["Expr"]:
        """Return the set of head facts derivable from ``facts`` by one rule."""
        return {
            rule.head.substitute(answer)
            for answer in self.matchGoals(rule.body, facts)
        }

    def matchGoals(self, goals: Sequence["Expr"], facts: AbstractSet["Expr"], binding: Optional["Binding"] = None) -> List["Binding"]:
        """Return every binding that satisfies all ``goals`` against ``facts``.

        ``binding`` holds what earlier goals have already bound.  An empty
        goal list is trivially satisfied and yields one answer, so a rule
        with no body derives its head instead of raising IndexError.
        """
        if not goals:
            return [Binding() if binding is None else binding]
        goal, rest = goals[0], goals[1:]
        answers: List["Binding"] = []
        for fact in facts:
            if fact.predicate != goal.predicate:
                continue
            # Give each candidate fact its own child binding so a failed
            # unification cannot leak entries into the other candidates.
            newBinding = Binding(binding)
            if fact.unify(goal, newBinding):
                if rest:
                    answers.extend(self.matchGoals(rest, facts, newBinding))
                else:
                    answers.append(newBinding)
        return answers
def isVariable(term) -> bool:
    """Return True if ``term`` is a string starting with an upper-case letter.

    The empty string and non-string terms are constants, not variables, so
    they return False instead of raising.
    """
    return isinstance(term, str) and term[:1].isupper()
if __name__ == "__main__":
    datalog = Datalog()

    # EDB from section 25.5 (p. 823) of Elmasri, as (parent, child) pairs.
    for parent, child in [
        ('bert', 'alice'),
        ('bert', 'george'),
        ('alice', 'derek'),
        ('alice', 'pat'),
        ('derek', 'frank'),
    ]:
        datalog.fact('parent', parent, child)

    datalog.rule(Expr('ancestor', 'X', 'Y'), Expr('parent', 'X', 'Y'))
    datalog.rule(Expr('ancestor', 'X', 'Y'), Expr('ancestor', 'X', 'Z'), Expr('parent', 'Z', 'Y'))

    goals = [Expr('ancestor', 'Anc', 'Des')]  # Ancestor/Descendant
    answers = datalog.query(goals)

    # for i in range(10000):
    #     goals = [Expr('ancestor', 'Anc', 'Des')] # Ancestor/Descendant
    #     answers = datalog.query(goals)

    print("----------------")
    for a in answers:
        print("answer: %s" % a)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment