Last active
December 11, 2017 23:16
-
-
Save Grissess/a687dedab7e5597c2d0e99e9b88ce1e7 to your computer and use it in GitHub Desktop.
Test scheduler please ignore
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
from minisat_model import * | |
class Intervalled(object):
    """Base for anything that occupies a time interval inside a Domain.

    ``full_name`` is the ``'_'``-joined text of ``args`` followed by this
    object's ``id()``, which keeps solver variable names unique. The interval
    is created with the domain's ``t_start``/``t_end`` as bounds.
    """

    def __init__(self, domain, *args):
        self.domain = domain
        self.full_name = '_'.join(format(part) for part in (*args, id(self)))
        self.interval = Interval(self.full_name, domain.t_start, domain.t_end)

    def __repr__(self):
        return self.full_name

    @apply_tracking
    def normal_model(self):
        """Constraint that this object's interval is well formed."""
        return NormalInterval(self.interval)

    def solution_update(self, mdl):
        """Copy the solved endpoints from ``mdl`` into ``self.interval``."""
        self.interval.solution_update(mdl)
class ResourceBindable(Intervalled):
    """An Intervalled item that can require resources via ResourceBindings.

    ``requires`` maps an integer priority to the ResourceBinding registered at
    that priority; bindings can themselves require further resources.
    """

    def __init__(self, domain, *args):
        Intervalled.__init__(self, domain, *args)
        self.requires = {}  # priority (int) -> ResourceBinding

    def bind(self, *res, priority=0, full_duration=True, optional=False):
        """Require one of ``res`` at ``priority`` and return the new binding.

        A later call with the same ``priority`` replaces the earlier binding
        in ``requires``.
        """
        binding = ResourceBinding(self.domain, res, self, full_duration, optional)
        self.requires[priority] = binding
        return binding

    @property
    def all_priorities(self):
        """Priorities used here and by every nested binding."""
        found = set(self.requires)
        for binding in self.requires.values():
            found.update(binding.all_priorities)
        return found

    @property
    def all_required_resources(self):
        """Candidate resources of every binding, recursively."""
        return set(itertools.chain.from_iterable(
            itertools.chain(binding.resources, binding.all_required_resources)
            for binding in self.requires.values()
        ))

    def resource_model(self, priority, tracker=None):
        """Binding constraints for every binding whose priority <= ``priority``.

        Nested bindings are included recursively under the same cut-off.
        """
        clauses = set()
        for pri, binding in self.requires.items():
            if pri > priority:
                continue
            clauses.add(binding.binding_model(tracker=tracker))
            clauses.add(binding.resource_model(priority, tracker=tracker))
        return And(*clauses)

    def full_normal_model(self, tracker=None):
        """Well-formedness of this interval and of all nested binding intervals."""
        clauses = {self.normal_model(tracker=tracker)}
        for binding in self.requires.values():
            clauses.add(binding.normal_model(tracker=tracker))
            clauses.add(binding.full_normal_model(tracker=tracker))
        return And(*clauses)
class Event(ResourceBindable):
    """A named schedulable item with an optional fixed duration.

    Constructing an Event registers it with ``domain`` via ``domain.add``.
    """

    def __init__(self, domain, name, duration=None):
        self.name = name
        self.duration = duration
        ResourceBindable.__init__(self, domain, name)
        domain.add(self)

    @apply_tracking
    def duration_model(self):
        """Constraint fixing the interval length to ``duration`` (if given)."""
        if self.duration is None:
            # No duration constraint. Return an empty conjunction (vacuously
            # true) rather than a Python ``True``: the minisat_model Tseytin
            # passes read ``.T`` from every operand and would crash on a bare
            # bool, while an empty And is encoded as ALWAYS_TRUE.
            return And()
        return HasDuration(self.interval, self.duration)
class ResourceBinding(ResourceBindable):
    """Requirement that ``binding`` uses one of ``resources``.

    Each candidate resource gets a Bool in ``in_use``. ``binding_model`` makes
    exactly one of them true, or, when ``optional``, allows all to be false.
    The binding's own interval either equals the bound item's interval
    (``full_duration``) or lies within it.
    """

    def __init__(self, domain, resources, binding, full_duration=True, optional=False):
        self.resources = resources
        if not isinstance(self.resources, tuple):
            self.resources = (self.resources,)
        ResourceBindable.__init__(self, domain, *self.resources, binding)
        # One selector per candidate resource, aligned with self.resources.
        self.in_use = tuple(
            Bool('{}_{}'.format(res.full_name, self.full_name))
            for res in self.resources
        )
        self.selected_resource = None
        self.binding = binding
        self.full_duration = full_duration
        self.optional = optional
        # Register with every candidate so Resource.conflict_model sees us.
        # (Must come after self.binding is set: __hash__ depends on it.)
        for resource in self.resources:
            resource.bindings.add(self)

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return (self.resources, self.binding) == (other.resources, other.binding)

    def __hash__(self):
        return hash((self.resources, self.binding))

    @apply_tracking
    def binding_model(self):
        """Timing relation to the bound item plus the resource-choice constraint."""
        if self.full_duration:
            core = Concurrent(self.binding.interval, self.interval)
        else:
            core = EntirelyWithin(self.binding.interval, self.interval)
        quant = ExactlyOne(*self.in_use)
        if self.optional:
            quant = Or(quant, And(*(Not(i) for i in self.in_use)))
        return And(quant, core)

    def is_in_use_var(self, res):
        """Selector Bool for ``res`` (ValueError if it is not a candidate)."""
        return self.in_use[self.resources.index(res)]

    def solution_update(self, mdl):
        """Read the solved interval and the chosen resource from ``mdl``."""
        Intervalled.solution_update(self, mdl)
        # Reset first: if no selector is true in this model (e.g. an optional
        # binding that was left unused), a selection from a previous solve
        # must not linger.
        self.selected_resource = None
        for res, selector in zip(self.resources, self.in_use):
            if mdl[selector]:
                self.selected_resource = res
                break
class Resource(Intervalled):
    """A named resource that at most one binding may use at any moment.

    ``bindings`` is filled by ResourceBinding's constructor with every binding
    that lists this resource as a candidate.
    """

    def __init__(self, domain, name):
        self.name = name
        self.bindings = set()
        Intervalled.__init__(self, domain, name)

    @apply_tracking
    def conflict_model(self):
        """Any two bindings that both use this resource must not overlap."""
        no_clash = {
            Implies(
                And(first.is_in_use_var(self), second.is_in_use_var(self)),
                Not(Overlaps(first.interval, second.interval, False)),
            )
            for first, second in itertools.combinations(self.bindings, 2)
        }
        return And(*no_clash)

    def solution_update(self, mdl):
        """Update this resource's interval, then every attached binding."""
        Intervalled.solution_update(self, mdl)
        for binding in self.bindings:
            binding.solution_update(mdl)
class Domain(object):
    """Scheduling horizon ``[t_start, t_end]`` plus the events placed in it.

    ``model(priority)`` conjoins the bounding, containment, well-formedness,
    duration, resource (bindings with priority <= ``priority``) and conflict
    constraints. ``solution_update`` pushes a solver model back into every
    interval.
    """

    def __init__(self, t_start, t_end):
        self.t_start = t_start
        self.t_end = t_end
        self.interval = Interval('domain', self.t_start, self.t_end)
        self.events = set()

    def add(self, *evs):
        """Register ``evs`` and point each one's ``domain`` at this domain."""
        self.events.update(evs)
        for event in evs:
            event.domain = self

    @property
    def all_priorities(self):
        """Every priority used by any event's (nested) bindings."""
        return set().union(*(event.all_priorities for event in self.events))

    @property
    def all_required_resources(self):
        """Every resource that any event requires, transitively."""
        return set().union(*(event.all_required_resources for event in self.events))

    @property
    def all_resource_bindings(self):
        """Every binding registered on any required resource."""
        return set().union(*(res.bindings for res in self.all_required_resources))

    @apply_tracking
    def normal_model(self):
        return NormalInterval(self.interval)

    def full_normal_model(self, tracker=None):
        """Well-formedness of the domain, every event tree and every resource."""
        own = self.normal_model(tracker=tracker)
        event_parts = [event.full_normal_model(tracker=tracker) for event in self.events]
        resource_parts = [res.normal_model(tracker=tracker) for res in self.all_required_resources]
        return And(own, *event_parts, *resource_parts)

    @apply_tracking
    def bounding_model(self):
        """Pin the domain interval to exactly ``[t_start, t_end]``."""
        start_fixed = StartsAt(self.interval, self.t_start)
        end_fixed = EndsAt(self.interval, self.t_end)
        return And(start_fixed, end_fixed)

    @apply_tracking
    def containing_model(self):
        """Every event and every resource binding lies within the domain."""
        inside = [EntirelyWithin(self.interval, event.interval, True) for event in self.events]
        inside.extend(
            EntirelyWithin(self.interval, rb.interval, True)
            for rb in self.all_resource_bindings
        )
        return And(*inside)

    def full_duration_model(self, tracker=None):
        return And(*[event.duration_model(tracker=tracker) for event in self.events])

    def full_resource_model(self, priority, tracker=None):
        return And(*[event.resource_model(priority, tracker=tracker) for event in self.events])

    def full_conflict_model(self, tracker=None):
        return And(*[res.conflict_model(tracker=tracker) for res in self.all_required_resources])

    def model(self, priority, tracker=None):
        """Complete constraint set, considering bindings up to ``priority``."""
        parts = (
            self.bounding_model(tracker=tracker),
            self.containing_model(tracker=tracker),
            self.full_normal_model(tracker=tracker),
            self.full_duration_model(tracker=tracker),
            self.full_resource_model(priority, tracker=tracker),
            self.full_conflict_model(tracker=tracker),
        )
        return And(*parts)

    def solution_update(self, mdl):
        """Propagate a solver model to the domain, its events and resources."""
        self.interval.solution_update(mdl)
        for event in self.events:
            event.solution_update(mdl)
        for res in self.all_required_resources:
            res.solution_update(mdl)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import z3 | |
from z3 import Or, And, Not, Implies, Bool | |
# Factory for the symbolic time points used by Interval (z3 integer constants).
TimeVar = z3.Int


def ExactlyOne(*conds):
    """Formula that is true iff exactly one of ``conds`` holds.

    Built as a disjunction over each choice of the single true condition,
    with every other condition negated (quadratic in ``len(conds)``).
    """
    alternatives = []
    for chosen in conds:
        alternatives.append(And(*(c if c is chosen else Not(c) for c in conds)))
    return Or(*alternatives)
def to_python(rat):
    """Convert a z3 numeral to ``int`` (integer numerals) or ``float`` (rationals)."""
    if not isinstance(rat, z3.IntNumRef):
        return rat.numerator().as_long() / rat.denominator().as_long()
    return rat.as_long()
class Interval(object):
    """A named time span whose endpoints are z3 time variables.

    ``min``/``max`` are accepted but not used by this backend.
    ``t_start``/``t_end`` hold the solved endpoints after ``solution_update``.
    """

    def __init__(self, name, min=None, max=None):
        self.name = name
        self.start = TimeVar('start_{}_{}'.format(name, id(self)))
        self.end = TimeVar('end_{}_{}'.format(name, id(self)))
        self.t_start = None
        self.t_end = None

    def solution_update(self, mdl):
        """Read both endpoints from z3 model ``mdl`` as Python numbers."""
        # model_completion=True makes z3 pick a value for an endpoint the
        # model leaves unassigned; without it eval() returns the symbolic
        # constant itself and to_python() cannot convert it.
        self.t_start = to_python(mdl.eval(self.start, model_completion=True))
        self.t_end = to_python(mdl.eval(self.end, model_completion=True))

    def __repr__(self):
        # Only show timing once both endpoints are known.
        if self.t_start is None or self.t_end is None:
            return '<Interval {}{}>'.format(self.name, '')
        return '<Interval {}{}>'.format(
            self.name,
            ' {}-{} ({})'.format(self.t_start, self.t_end, self.t_end - self.t_start),
        )
def NormalInterval(a):
    """The interval does not end before it starts."""
    return a.end >= a.start


def StartsAfter(before, after, at=False):
    """``after`` starts strictly after ``before`` (or at the same time if ``at``)."""
    return before.start <= after.start if at else before.start < after.start


def EndsAfter(before, after, at=False):
    """``after`` ends strictly after ``before`` (or at the same time if ``at``)."""
    return before.end <= after.end if at else before.end < after.end


def StartsWith(a, b):
    """Both intervals share a start point."""
    return a.start == b.start


def EndsWith(a, b):
    """Both intervals share an end point."""
    return a.end == b.end


def StartsAt(a, t):
    """The interval starts at time ``t``."""
    return a.start == t


def EndsAt(a, t):
    """The interval ends at time ``t``."""
    return a.end == t


def HasDuration(a, t):
    """The interval is exactly ``t`` long."""
    return a.end == a.start + t
def EntirelyWithin(outer, inner, at=True):
    """``inner`` lies inside ``outer`` (endpoints may coincide when ``at``)."""
    starts_inside = StartsAfter(outer, inner, at)
    ends_inside = EndsAfter(inner, outer, at)
    return And(starts_inside, ends_inside)


def OverlapsTail(before, after, at=False):
    """``before`` starts no later than ``after`` and runs into it.

    With ``at`` an end touching the other's start counts as overlapping.
    """
    reaches = before.end >= after.start if at else before.end > after.start
    return And(reaches, before.start <= after.start)


def Overlaps(a, b, at=False):
    """The two intervals overlap, in either order."""
    return Or(OverlapsTail(a, b, at), OverlapsTail(b, a, at))


def Concurrent(a, b):
    """The intervals have identical start and end points."""
    return And(StartsWith(a, b), EndsWith(a, b))


def Disjoint(*ivs):
    """No two of ``ivs`` overlap (touching endpoints are allowed)."""
    separated = {Not(Overlaps(a, b, False)) for a, b in itertools.combinations(ivs, 2)}
    return And(*separated)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator, functools | |
from satispy import Variable, Cnf | |
# Alias so model code can write Bool(...), the same name the z3 backends import.
Bool = Variable


# XXX hack: give satispy Variables a readable repr, e.g. Variable('x').
def _Variable_repr(self):
    return '%s(%r)' % (self.__class__.__name__, self.name)


Variable.__repr__ = _Variable_repr
# XXX mega hack: rewrite an expression node in place.
def transmute(v, nv):
    """Turn object ``v`` into ``nv`` in place.

    ``v`` takes ``nv``'s class and shares (not copies) its ``__dict__``, so
    every existing reference to ``v`` now sees the replacement.
    """
    v.__class__, v.__dict__ = nv.__class__, nv.__dict__
# Variable used as the constant "true": Lit.model emits it as a unit clause,
# and empty And/Or nodes are encoded as it or its negation.
ALWAYS_TRUE = Variable('ALWAYS_TRUE')


def tseytin(ex):
    """Tseytin-encode expression ``ex`` and return ``(cnf, output_var)``.

    ``ex`` is rewritten in place first (Implies elimination, De Morgan,
    double-negation elimination). The returned clauses only *define*
    ``output_var`` in terms of ``ex`` (see BinTree.model); to require ``ex``
    itself to hold, conjoin ``output_var`` with the CNF.
    """
    for normalize in (Ex.aoize, Ex.de_morganize, Ex.dn_elim):
        normalize(ex)
    tree = BinTree.from_ex(ex)
    return tree.model()
class Lit(object):
    """Leaf of a BinTree: a Variable, optionally negated (``pos=False``)."""

    def __init__(self, var, pos=True):
        self.var = var
        self.pos = pos

    def model(self):
        """Return ``(ALWAYS_TRUE, literal)``: the truth constant as the clause part, plus the literal."""
        literal = self.var if self.pos else -self.var
        return ALWAYS_TRUE, literal

    def eval_with(self, mdl):
        """Evaluate under assignment ``mdl`` (indexable by variable)."""
        value = mdl[self.var]
        return value if self.pos else not value

    @classmethod
    def from_ex(cls, n):
        """Build a Lit from a Variable or ``Not(Variable)``; anything else raises TypeError."""
        if Ex.is_var(n):
            return Lit(n)
        is_negated_var = n.T == 'Not' and Ex.is_var(n.ex)
        if not is_negated_var:
            raise TypeError(type(n), 'not aoid, de_morganized and/or dn_elimd')
        return Lit(n.ex, False)

    def __repr__(self):
        return '{sign}{var}'.format(sign='' if self.pos else '-', var=self.var)
class BinTree(object):
    """Binary And/Or node of the Tseytin encoding.

    ``var`` is a fresh Variable that ``model()`` constrains to equal
    ``l <o> r``; children are BinTree or Lit nodes.
    """

    def __init__(self, l, o, r):
        self.l = l
        self.o = o
        self.r = r
        self.var = Variable('bt_{}'.format(id(self)))

    def eval_with(self, mdl):
        """Evaluate the subtree under assignment ``mdl``."""
        if self.o == 'And':
            return self.l.eval_with(mdl) and self.r.eval_with(mdl)
        else:
            return self.l.eval_with(mdl) or self.r.eval_with(mdl)

    def model(self):
        """Return ``(cnf, var)``: children's clauses plus ``var <-> (l o r)``."""
        lex, lov = self.l.model()
        rex, rov = self.r.model()
        ov = self.var
        if self.o == 'And':
            return lex & rex & (-lov | -rov | ov) & (lov | -ov) & (rov | -ov), ov
        else:
            return lex & rex & (lov | rov | -ov) & (-lov | ov) & (-rov | ov), ov

    @classmethod
    def from_ex(cls, n):
        """Convert an aoized, de_morganized, dn_elimd expression to a tree.

        Empty And/Or become the constants true/false via ALWAYS_TRUE.
        """
        if Ex.is_var(n) or n.T == 'Not':
            return Lit.from_ex(n)
        if n.T not in ('And', 'Or'):
            raise TypeError(type(n))
        c = n.children()
        if not c:
            if n.T == 'And':
                return Lit.from_ex(ALWAYS_TRUE)
            else:
                return Lit.from_ex(Not(ALWAYS_TRUE))
        return cls._balanced([cls.from_ex(ch) for ch in c], n.T)

    @classmethod
    def _balanced(cls, nodes, op):
        """Combine ``nodes`` under ``op`` into a tree of logarithmic depth.

        A left-deep chain would make model()/eval_with() recurse once per
        operand and overflow the stack on wide And/Or nodes.
        """
        if len(nodes) == 1:
            return nodes[0]
        mid = len(nodes) // 2
        return cls(cls._balanced(nodes[:mid], op), op, cls._balanced(nodes[mid:], op))

    def __repr__(self):
        return '{}({!r}, {!r}, {!r})'.format(
            type(self).__name__, self.l, self.o, self.r,
        )
class Ex(object):
    """Base class of this module's propositional expression nodes.

    Subclasses set ``T`` ('Not', 'And', 'Or' or 'Implies') and implement
    ``children()``; leaves are satispy ``Variable`` objects. The rewriting
    passes below mutate nodes in place via ``transmute``.
    """

    VAR_NUM = 0  # class-wide counter used by gen_var

    @classmethod
    def gen_var(cls):
        """Return the next counter value (0, 1, 2, ...)."""
        Ex.VAR_NUM += 1
        return Ex.VAR_NUM - 1

    def children(self):
        raise NotImplementedError()

    def __repr__(self):
        return '{}{!r}'.format(type(self).__name__, self.children())

    @staticmethod
    def is_var(n):
        """True iff ``n`` is exactly a satispy Variable (a leaf)."""
        return n.__class__ is Variable

    @staticmethod
    def _strip_double_negations(n):
        """Rewrite ``n`` = Not(Not(x)) into ``x`` in place, as often as needed."""
        while (not Ex.is_var(n) and n.T == 'Not'
               and not Ex.is_var(n.ex) and n.ex.T == 'Not'):
            transmute(n, n.ex.ex)

    @staticmethod
    def aoize(n):
        """Replace every ``Implies(l, r)`` with ``Or(Not(l), r)`` in place."""
        if Ex.is_var(n):
            return
        if n.T == 'Implies':
            l, r = n.children()
            transmute(n, Or(Not(l), r))
        for c in n.children():
            Ex.aoize(c)

    @staticmethod
    def de_morganize(n):
        """Push negations down to the leaves in place (expects aoized input)."""
        if Ex.is_var(n):
            return
        # Remove Not(Not(...)) first: otherwise Not(Not(And(...))) leaves the
        # outer Not above an Or once the inner one is rewritten, and that
        # shape is rejected later by Lit.from_ex.
        Ex._strip_double_negations(n)
        if Ex.is_var(n):
            return
        if n.T == 'Not' and (not Ex.is_var(n.ex)) and n.ex.T in ('And', 'Or'):
            c = n.ex.children()
            if n.ex.T == 'And':
                inst = Or(*map(Not, c))
            else:
                inst = And(*map(Not, c))
            transmute(n, inst)
        for c in n.children():
            Ex.de_morganize(c)

    @staticmethod
    def dn_elim(n):
        """Eliminate every double negation in place, including stacked ones."""
        Ex._strip_double_negations(n)
        if Ex.is_var(n):
            return
        for c in n.children():
            Ex.dn_elim(c)

    @staticmethod
    def eval_with(n, mdl):
        """Evaluate expression ``n`` under assignment ``mdl``."""
        if Ex.is_var(n):
            return mdl[n]
        if n.T == 'Not':
            return not Ex.eval_with(n.ex, mdl)
        if n.T == 'And':
            return functools.reduce(lambda x, y: x and Ex.eval_with(y, mdl), n.exs, True)
        if n.T == 'Or':
            return functools.reduce(lambda x, y: x or Ex.eval_with(y, mdl), n.exs, False)
        if n.T == 'Implies':
            return (not Ex.eval_with(n.lex, mdl)) or Ex.eval_with(n.rex, mdl)
        raise TypeError(type(n))
class ComEx(Ex):
    """Node with any number of operands (base of And/Or), stored in ``exs``."""

    def __init__(self, *exs):
        self.exs = [*exs]

    def children(self):
        return self.exs


class UnEx(Ex):
    """Node with a single operand ``ex`` (base of Not)."""

    def __init__(self, ex):
        self.ex = ex

    def children(self):
        return [self.ex]


class BinEx(Ex):
    """Node with a left operand ``lex`` and a right operand ``rex`` (base of Implies)."""

    def __init__(self, lex, rex):
        self.lex = lex
        self.rex = rex

    def children(self):
        return [self.lex, self.rex]


class Not(UnEx):
    T = 'Not'


class And(ComEx):
    T = 'And'


class Or(ComEx):
    T = 'Or'


class Implies(BinEx):
    T = 'Implies'
def apply_tracking(f):
    """Make method ``f`` accept (and ignore) a ``tracker=`` keyword.

    Keeps the call signature compatible with the z3 backend, where trackers
    record assertions. ``functools.wraps`` preserves ``f``'s name and
    docstring for introspection and tracebacks.
    """
    @functools.wraps(f)
    def _tracked(self, *args, tracker=None):
        return f(self, *args)
    return _tracked
class FiniteDomain(object):
    """Integer variable over ``[min, max]``, one-hot encoded with Bools.

    ``vars[i]`` stands for "value is ``min + i``". Comparison operators do
    not return bools; they return constraint expressions relating two
    domains (``__eq__`` is overridden, so instances are unhashable).
    """

    def __init__(self, name, min, max):
        self.name = name
        self.min = min
        self.max = max
        self.vars = tuple(
            Bool('{}_{}_{}'.format(name, value, id(self)))
            for value in self.values
        )
        self.t_value = None

    @property
    def range(self):
        return self.max - self.min

    @property
    def values(self):
        return range(self.min, self.max + 1)

    def selected(self, mdl):
        """First value whose selector is true in ``mdl``, or None."""
        return next(
            (value for value, var in zip(self.values, self.vars) if mdl[var]),
            None,
        )

    def unique_model(self):
        """Exactly one value is selected."""
        return ExactlyOne(*self.vars)

    def solution_update(self, mdl):
        self.t_value = self.selected(mdl)

    def map_binpred(self, other, binpred):
        """Constraint that ``binpred(self_value, other_value)`` holds.

        For each value of ``self``, selecting it forbids every value of
        ``other`` for which the predicate fails.
        """
        implications = set()
        for value, var in zip(self.values, self.vars):
            excluded = {
                Not(other_var)
                for other_value, other_var in zip(other.values, other.vars)
                if not binpred(value, other_value)
            }
            if excluded:
                implications.add(Implies(var, And(*excluded)))
        return And(*implications)

    def __lt__(self, other):
        return self.map_binpred(other, operator.lt)

    def __le__(self, other):
        return self.map_binpred(other, operator.le)

    def __gt__(self, other):
        return self.map_binpred(other, operator.gt)

    def __ge__(self, other):
        return self.map_binpred(other, operator.ge)

    def __eq__(self, other):
        return self.eq_offset(other, 0)

    def __ne__(self, other):
        return Not(self.__eq__(other))

    def set(self, value):
        """Constraint fixing this variable to ``value``."""
        literals = [
            var if candidate == value else Not(var)
            for candidate, var in zip(self.values, self.vars)
        ]
        return And(*literals)

    def eq_offset(self, other, offset):
        """Constraint ``other == self + offset``."""
        # Key each of other's selectors by the self-value it pairs with.
        partner = {oval - offset: ovar for oval, ovar in zip(other.values, other.vars)}
        ruled_out = set()
        matches = set()
        for value, var in zip(self.values, self.vars):
            ovar = partner.get(value)
            if ovar is None:
                ruled_out.add(Not(var))
            else:
                matches.add(And(var, ovar))
        return And(*ruled_out, Or(*matches))

    def __repr__(self):
        suffix = '' if self.t_value is None else ' ({})'.format(self.t_value)
        return '<FiniteDomain {} {}-{}{}>'.format(self.name, self.min, self.max, suffix)
class Interval(object):
    """A named span whose start and end are FiniteDomains over ``[min, max]``."""

    def __init__(self, name, min, max):
        self.name = name
        self.start = FiniteDomain('{}_start'.format(name), min, max)
        self.end = FiniteDomain('{}_end'.format(name), min, max)
        self.t_start = self.t_end = None

    def unique_model(self):
        """Each endpoint takes exactly one value."""
        return And(self.start.unique_model(), self.end.unique_model())

    def solution_update(self, mdl):
        """Read both endpoints' selected values from ``mdl``."""
        for endpoint in (self.start, self.end):
            endpoint.solution_update(mdl)
        self.t_start, self.t_end = self.start.t_value, self.end.t_value

    def __repr__(self):
        solved = self.t_start is not None and self.t_end is not None
        detail = (
            ' {}-{} ({})'.format(self.t_start, self.t_end, self.t_end - self.t_start)
            if solved else ''
        )
        return '<Interval {}{}>'.format(self.name, detail)
def NormalInterval(a):
    """Each endpoint has exactly one value and the start is not after the end."""
    return And(a.unique_model(), a.start <= a.end)


def StartsAt(a, t):
    """The interval starts at ``t``."""
    return a.start.set(t)


def EndsAt(a, t):
    """The interval ends at ``t``."""
    return a.end.set(t)


def HasDuration(a, t):
    """The interval is exactly ``t`` long (end == start + t)."""
    return a.start.eq_offset(a.end, t)


def StartsAfter(before, after, at=False):
    """``after`` starts strictly after ``before`` (or at the same time if ``at``)."""
    return before.start <= after.start if at else before.start < after.start


def EndsAfter(before, after, at=False):
    """``after`` ends strictly after ``before`` (or at the same time if ``at``)."""
    return before.end <= after.end if at else before.end < after.end


def StartsWith(a, b):
    """Both intervals share a start point."""
    return a.start == b.start


def EndsWith(a, b):
    """Both intervals share an end point."""
    return a.end == b.end
def EntirelyWithin(outer, inner, at=True):
    """``inner`` lies inside ``outer`` (endpoints may coincide when ``at``)."""
    return And(StartsAfter(outer, inner, at), EndsAfter(inner, outer, at))


def OverlapsTail(before, after, at=False):
    """``before`` starts no later than ``after`` and runs into it.

    With ``at`` an end touching the other's start counts as overlapping.
    """
    if at:
        return And(before.end >= after.start, before.start <= after.start)
    return And(before.end > after.start, before.start <= after.start)


def Overlaps(a, b, at=False):
    """The two intervals overlap, in either order."""
    return Or(
        OverlapsTail(a, b, at),
        OverlapsTail(b, a, at),
    )


def Concurrent(a, b):
    """The intervals have identical start and end points."""
    return And(StartsWith(a, b), EndsWith(a, b))


def Disjoint(*ivs):
    """No two of ``ivs`` overlap (touching endpoints are allowed)."""
    # This module does not import itertools at top level; without this
    # local import Disjoint raised NameError on every call.
    import itertools
    clauses = set()
    for a, b in itertools.combinations(ivs, 2):
        clauses.add(Not(Overlaps(a, b, False)))
    return And(*clauses)
def ExactlyOne(*conds):
    """Expression that is true iff exactly one of ``conds`` holds.

    A disjunction over each choice of the single true condition, with every
    other condition negated (quadratic in ``len(conds)``).
    """
    alternatives = []
    for chosen in conds:
        alternatives.append(And(*(c if c is chosen else Not(c) for c in conds)))
    return Or(*alternatives)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import z3 | |
from z3 import Or, And, Not, Implies, Bool | |
# Factory for the symbolic time points used by Interval (z3 integer constants).
TimeVar = z3.Int


def ExactlyOne(*conds):
    """Formula that is true iff exactly one of ``conds`` holds.

    Built as a disjunction over each choice of the single true condition,
    with every other condition negated (quadratic in ``len(conds)``).
    """
    alternatives = []
    for chosen in conds:
        alternatives.append(And(*(c if c is chosen else Not(c) for c in conds)))
    return Or(*alternatives)
def to_python(rat):
    """Convert a z3 numeral to ``int`` (integer numerals) or ``float`` (rationals)."""
    if not isinstance(rat, z3.IntNumRef):
        return rat.numerator().as_long() / rat.denominator().as_long()
    return rat.as_long()
class Interval(object):
    """A named time span whose endpoints are z3 time variables.

    ``t_start``/``t_end`` hold the solved endpoints after ``solution_update``.
    """

    def __init__(self, name):
        self.name = name
        self.start = TimeVar('start_{}_{}'.format(name, id(self)))
        self.end = TimeVar('end_{}_{}'.format(name, id(self)))
        self.t_start = None
        self.t_end = None

    def solution_update(self, mdl):
        """Read both endpoints from z3 model ``mdl`` as Python numbers."""
        # model_completion=True makes z3 pick a value for an endpoint the
        # model leaves unassigned; without it eval() returns the symbolic
        # constant itself and to_python() cannot convert it.
        self.t_start = to_python(mdl.eval(self.start, model_completion=True))
        self.t_end = to_python(mdl.eval(self.end, model_completion=True))

    def __repr__(self):
        # Only show timing once both endpoints are known.
        if self.t_start is None or self.t_end is None:
            return '<Interval {}{}>'.format(self.name, '')
        return '<Interval {}{}>'.format(
            self.name,
            ' {}-{} ({})'.format(self.t_start, self.t_end, self.t_end - self.t_start),
        )
def NormalInterval(a):
    """The interval does not end before it starts."""
    return a.end >= a.start


def StartsAfter(before, after, at=False):
    """``after`` starts strictly after ``before`` (or at the same time if ``at``)."""
    return before.start <= after.start if at else before.start < after.start


def EndsAfter(before, after, at=False):
    """``after`` ends strictly after ``before`` (or at the same time if ``at``)."""
    return before.end <= after.end if at else before.end < after.end


def StartsWith(a, b):
    """Both intervals share a start point."""
    return a.start == b.start


def EndsWith(a, b):
    """Both intervals share an end point."""
    return a.end == b.end


def StartsAt(a, t):
    """The interval starts at time ``t``."""
    return a.start == t


def EndsAt(a, t):
    """The interval ends at time ``t``."""
    return a.end == t


def HasDuration(a, t):
    """The interval is exactly ``t`` long."""
    return a.end == a.start + t
def EntirelyWithin(outer, inner, at=True):
    """``inner`` lies inside ``outer`` (endpoints may coincide when ``at``)."""
    starts_inside = StartsAfter(outer, inner, at)
    ends_inside = EndsAfter(inner, outer, at)
    return And(starts_inside, ends_inside)


def OverlapsTail(before, after, at=False):
    """``before`` starts no later than ``after`` and runs into it.

    With ``at`` an end touching the other's start counts as overlapping.
    """
    reaches = before.end >= after.start if at else before.end > after.start
    return And(reaches, before.start <= after.start)


def Overlaps(a, b, at=False):
    """The two intervals overlap, in either order."""
    return Or(OverlapsTail(a, b, at), OverlapsTail(b, a, at))


def Concurrent(a, b):
    """The intervals have identical start and end points."""
    return And(StartsWith(a, b), EndsWith(a, b))


def Disjoint(*ivs):
    """No two of ``ivs`` overlap (touching endpoints are allowed)."""
    separated = {Not(Overlaps(a, b, False)) for a, b in itertools.combinations(ivs, 2)}
    return And(*separated)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator, pprint | |
from ida_model import ExactlyOne, to_python, Or, And, Not, Implies, Bool, \ | |
StartsAfter, EndsAfter, StartsWith, EndsWith, EntirelyWithin, \ | |
OverlapsTail, Overlaps, Concurrent, Disjoint | |
import ida_model | |
import z3 | |
def get_vars(astex):
    """Return the names of all uninterpreted constants in z3 expression ``astex``.

    Walks the expression iteratively and visits each distinct sub-term only
    once (keyed by its z3 AST id). Shared sub-terms therefore are not
    re-traversed along every path, and deep expressions cannot exceed
    Python's recursion limit.
    """
    names = set()
    seen = set()
    pending = [astex]
    while pending:
        node = pending.pop()
        node_id = node.get_id()
        if node_id in seen:
            continue
        seen.add(node_id)
        if z3.is_const(node) and node.decl().kind() == z3.Z3_OP_UNINTERPRETED:
            names.add(str(node))
        else:
            pending.extend(node.children())
    return names
class AssertionTracker(object):
    """Collects named assertions so they can be tracked on a z3 solver.

    ``apply_to`` asserts each entry with its own tracking Bool, which lets an
    unsat core be mapped back to the names recorded here.
    """

    def __init__(self):
        self.map = {}  # name -> expression

    def add(self, name, expr):
        """Record ``expr`` under ``name``; a repeated name replaces the earlier entry."""
        # assert name not in self.map, name
        self.map[name] = expr

    def apply_to(self, slv):
        """``assert_and_track`` every entry on ``slv``; return the set of tracking Bools."""
        tracking = set()
        for name in self.map:
            label = Bool(name)
            slv.assert_and_track(self.map[name], label)
            tracking.add(label)
        return tracking

    def __repr__(self):
        return pprint.pformat(self.map)
def apply_tracking(f):
    """Decorate a model method so that callers can pass ``tracker=``.

    When a tracker is given, the method's result is recorded on it under
    ``'<owner full_name or id>_<method name>'``. ``functools.wraps``
    preserves ``f``'s name and docstring on the wrapper.
    """
    import functools

    @functools.wraps(f)
    def _tracked(self, *args, tracker=None):
        res = f(self, *args)
        if tracker is not None:
            nm = getattr(self, 'full_name', str(id(self)))
            tracker.add('{}_{}'.format(nm, f.__name__), res)
        return res
    return _tracked
class FiniteDomain(object):
    """Integer variable over ``[min, max]``, one-hot encoded with z3 Bools.

    ``vars[i]`` stands for "value is ``min + i``". Comparison operators do
    not return bools; they return constraint expressions relating two
    domains (``__eq__`` is overridden, so instances are unhashable).
    """

    def __init__(self, name, min, max):
        self.name = name
        self.min = min
        self.max = max
        self.vars = tuple(
            Bool('{}_{}_{}'.format(name, value, id(self)))
            for value in self.values
        )
        self.t_value = None

    @property
    def range(self):
        return self.max - self.min

    @property
    def values(self):
        return range(self.min, self.max + 1)

    def selected(self, mdl):
        """First value whose selector evaluates true in ``mdl``, or None.

        Selectors that cannot be converted to a concrete bool (z3 raises
        Z3Exception) are skipped.
        """
        for value, var in zip(self.values, self.vars):
            try:
                is_chosen = bool(mdl.eval(var))
            except z3.Z3Exception:
                continue
            if is_chosen:
                return value
        return None

    def unique_model(self):
        """Exactly one value is selected."""
        return ExactlyOne(*self.vars)

    def solution_update(self, mdl):
        self.t_value = self.selected(mdl)

    def map_binpred(self, other, binpred):
        """Constraint that ``binpred(self_value, other_value)`` holds.

        For each value of ``self``, selecting it forbids every value of
        ``other`` for which the predicate fails.
        """
        implications = set()
        for value, var in zip(self.values, self.vars):
            excluded = {
                Not(other_var)
                for other_value, other_var in zip(other.values, other.vars)
                if not binpred(value, other_value)
            }
            if excluded:
                implications.add(Implies(var, And(*excluded)))
        return And(*implications)

    def __lt__(self, other):
        return self.map_binpred(other, operator.lt)

    def __le__(self, other):
        return self.map_binpred(other, operator.le)

    def __gt__(self, other):
        return self.map_binpred(other, operator.gt)

    def __ge__(self, other):
        return self.map_binpred(other, operator.ge)

    def __eq__(self, other):
        return self.eq_offset(other, 0)

    def __ne__(self, other):
        return Not(self.__eq__(other))

    def set(self, value):
        """Constraint fixing this variable to ``value``."""
        literals = [
            var if candidate == value else Not(var)
            for candidate, var in zip(self.values, self.vars)
        ]
        return And(*literals)

    def eq_offset(self, other, offset):
        """Constraint ``other == self + offset``."""
        # Key each of other's selectors by the self-value it pairs with.
        partner = {oval - offset: ovar for oval, ovar in zip(other.values, other.vars)}
        ruled_out = set()
        matches = set()
        for value, var in zip(self.values, self.vars):
            ovar = partner.get(value)
            if ovar is None:
                ruled_out.add(Not(var))
            else:
                matches.add(And(var, ovar))
        return And(*ruled_out, Or(*matches))

    def __repr__(self):
        suffix = '' if self.t_value is None else ' ({})'.format(self.t_value)
        return '<FiniteDomain {} {}-{}{}>'.format(self.name, self.min, self.max, suffix)
class Interval(object):
    """A named span whose start and end are FiniteDomains over ``[min, max]``."""

    def __init__(self, name, min, max):
        self.name = name
        self.start = FiniteDomain('{}_start'.format(name), min, max)
        self.end = FiniteDomain('{}_end'.format(name), min, max)
        self.t_start = None
        self.t_end = None

    def unique_model(self):
        """Each endpoint takes exactly one value."""
        return And(self.start.unique_model(), self.end.unique_model())

    def solution_update(self, mdl):
        """Read both endpoints' selected values from ``mdl`` (either may be None)."""
        self.start.solution_update(mdl)
        self.end.solution_update(mdl)
        self.t_start = self.start.t_value
        self.t_end = self.end.t_value

    def __repr__(self):
        # FiniteDomain.selected() can return None for one endpoint only, so
        # both must be checked before subtracting.
        return '<Interval {}{}>'.format(
            self.name,
            ' {}-{} ({})'.format(self.t_start, self.t_end, self.t_end - self.t_start)
            if self.t_start is not None and self.t_end is not None else '',
        )
def NormalInterval(a):
    """Each endpoint has exactly one value and the start is not after the end."""
    return And(a.unique_model(), a.start <= a.end)


def StartsAt(a, t):
    """The interval starts at ``t``."""
    start = a.start
    return start.set(t)


def EndsAt(a, t):
    """The interval ends at ``t``."""
    end = a.end
    return end.set(t)


def HasDuration(a, t):
    """The interval is exactly ``t`` long (end == start + t)."""
    start, end = a.start, a.end
    return start.eq_offset(end, t)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Driver: schedule ten 1-unit events into two rooms over [0, 5] with MiniSat.
import time

from anise import *
# import z3
# Import the solver submodule explicitly rather than relying on
# ``import satispy`` having loaded it.
import satispy.solver

dm = Domain(0, 5)
evs = tuple(Event(dm, i, 1) for i in 'abcdefghij')
rms = tuple(Resource(dm, 'room' + str(i)) for i in (1, 2))
for ev in evs:
    ev.bind(*rms)

# tk = AssertionTracker()
tk = None

print(time.ctime(), 'Generating model')
clauses = dm.model(0, tracker=tk)
# print('Counting vars...')
# n_vars = get_vars(clauses)
# print('Done:', len(n_vars), 'variables')
# print('...done, here they are:')
print(clauses)
clauses, output = tseytin(clauses)
# print('Tracker:', tk)
# print(time.ctime(), 'Simplifying...')
# simp = z3.simplify(clauses)

print(time.ctime(), 'Checking model')
# slv = z3.Solver()
slv = satispy.solver.Minisat()
# z3 alternative:
# assm = tk.apply_to(slv)
# res = slv.check(*assm)
# slv.add(clauses)
# res = slv.check()
# if res != z3.sat:
#     print('Not satisfiable :(')
#     print('Result:', res)
#     core = slv.unsat_core()
#     print('Core:', core)
#     print('Assumptions:')
#     for v in core:
#         print(v, ':')
#         print(tk.map[str(v)])
#         print()
# else:
#     mdl = slv.model()
#     dm.solution_update(mdl)
#     print(time.ctime(), 'Done!')

# tseytin() returns clauses that only *define* ``output`` as equivalent to
# the whole model; ``output`` itself must be asserted too, otherwise the CNF
# is trivially satisfiable and says nothing about the schedule.
res = slv.solve(clauses & output)
if res.success:
    dm.solution_update(res)
    print(time.ctime(), 'Done!')
else:
    print('Not satisfiable :(')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment