from copy import deepcopy
import functools
import itertools
import operator

class Attribute(tuple):
    """A tuple of the possible values of one attribute."""

class Instance(tuple):
    """A tuple of attribute values; the last element is the '+'/'-' label
    (or None for unlabelled query instances)."""

class WildcardType:
    """Singleton type for the '?' feature that accepts any attribute value."""

    def __repr__(self):
        return '?'

    def __eq__(self, other):
        return is_wildcard(other)

    def __hash__(self):
        # All wildcards hash identically (even copies made by deepcopy), so
        # conjunctions containing them compare consistently inside sets.
        return object.__hash__(Wildcard)

Wildcard = WildcardType()

def is_wildcard(x):
    return isinstance(x, WildcardType)

def prod(iterable):
    """Product of the elements of an iterable (empty product is 1)."""
    return functools.reduce(operator.mul, iterable, 1)

def parse_dataset(dataset):
    """Parse a dataset string into ([Attribute], [Instance]).

    Lines starting with '#' are comments.  Attribute lines come first, one
    per attribute; the first blank line separates them from the instance
    lines.  Fields are separated by ', '.
    """
    lines = dataset.splitlines()
    stripped_lines = map(str.strip, lines)
    filtered = [line for line in stripped_lines if not line.startswith('#')]
    attributes = []
    instances = []
    seen_blank = False
    for line in filtered:
        if seen_blank:
            if line != '':
                instances.append(line)
        else:
            if line == '':
                seen_blank = True
            else:
                attributes.append(line)
    attributes = [Attribute(attr.split(', ')) for attr in attributes]
    instances = [Instance(inst.split(', ')) for inst in instances]
    return (attributes, instances)

def print_statistics(dataset):
    attributes, instances = parse_dataset(dataset)
    attr_lengths = [len(attr) for attr in attributes]
    # Size of the instance space, of the concept space over it, and of the
    # semantically distinct conjunctive hypothesis space (each attribute may
    # be a value or the wildcard, plus the single match-nothing hypothesis).
    num_instances = prod(attr_lengths)
    num_poss_distinct_concepts = 2**num_instances
    num_sema_distinct_hypotheses = 1 + prod(x + 1 for x in attr_lengths)
    num_distinct_instances = len(set(instances))
    print(num_instances)
    print(num_poss_distinct_concepts)
    print(num_poss_distinct_concepts // (2**num_distinct_instances))
    print(num_sema_distinct_hypotheses)

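# Usage sketch (not part of the original gist).  The dataset layout below is
# inferred from parse_dataset above: '#' lines are comments, attribute lines
# come first, the first blank line separates them from the instances, fields
# are separated by ', ', and the last field of each instance is its '+'/'-'
# label.  The attribute values are made up for illustration.
EXAMPLE_DATASET = """\
# hypothetical two-attribute dataset
Sunny, Cloudy, Rainy
Warm, Cold

Sunny, Warm, +
Rainy, Cold, -
"""

def example_parse():
    """Parse the hypothetical dataset above and sanity-check its shape."""
    attributes, instances = parse_dataset(EXAMPLE_DATASET)
    assert len(attributes) == 2 and len(instances) == 2
    return attributes, instances
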
class Hypothesis:
    pass

class Conjunction(Hypothesis, tuple):
    """Represents a hypothesis as a conjunction of attribute constraints.

    Each element is a specific attribute value, the Wildcard ('?', accept
    anything), or None (the null constraint that accepts nothing).
    """

    # Set by cea_trace(); used when enumerating the version space.
    attrs = None

    def match(self, instance):
        # Drop the instance's label before comparing feature-by-feature.
        return self >= Conjunction(instance[:-1])

    def min_gen(self, inst):
        """Minimal generalisations of this hypothesis that cover inst."""
        return [self._min_gen(inst)]

    def _min_gen(self, inst):
        features = []
        for i in range(len(self)):
            if self[i] is None:
                features.append(inst[i])
            elif is_wildcard(self[i]) or self[i] == inst[i]:
                features.append(self[i])
            else:
                features.append(Wildcard)
        return Conjunction(tuple(features))

    def min_spec(self, neginst, attributes):
        """Minimal specialisations of this hypothesis that exclude neginst."""
        return list(self._min_spec(neginst, attributes))

    def _min_spec(self, neginst, attributes):
        for i in range(len(self)):
            if is_wildcard(self[i]):
                for v in attributes[i]:
                    hypothesis = Conjunction(v if i == j else self[j] for j in range(len(self)))
                    if not hypothesis.match(neginst):
                        yield hypothesis

    def is_null(self):
        """Return whether this represents a null hypothesis."""
        return None in self

    def __hash__(self):
        # All null hypotheses are semantically identical, so hash them alike.
        if self.is_null():
            return tuple.__hash__(Conjunction(None for x in self))
        return tuple.__hash__(self)

    def __eq__(self, other):
        if self.is_null():
            return other.is_null()
        return tuple.__eq__(self, other)

    def __ne__(self, other):
        if self.is_null():
            return not other.is_null()
        return tuple.__ne__(self, other)

    def __ge__(self, other):
        # 'More general than or equal to': every constraint of self covers
        # the corresponding constraint of other.
        return bool(prod(is_wildcard(s) or o is None or s == o for (s, o) in zip(self, other)))

    def __le__(self, other):
        return other >= self

    def __gt__(self, other):
        return self >= other and self != other

    def __lt__(self, other):
        return self <= other and self != other

def ccge(h1, h2):
    return h1 >= h2

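# Illustrative sketch (not from the original gist): >= above is the "more
# general than or equal to" relation over conjunctions, where the wildcard
# covers any specific value.  The attribute values here are made up.
def example_generality():
    h_general = Conjunction((Wildcard, 'Warm'))
    h_specific = Conjunction(('Sunny', 'Warm'))
    assert h_general >= h_specific        # the wildcard covers 'Sunny'
    assert not (h_specific >= h_general)  # a specific value does not cover '?'
    assert h_general > h_specific         # strictly more general
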
def positive_step(d, S, G):
    """Update the S and G boundaries for a positive training instance d."""
    # Drop general hypotheses that fail to cover the positive instance.
    G = {g for g in G if g.match(d)}
    # Minimally generalise any specific hypotheses that fail to cover it,
    # keeping only generalisations still bounded above by some member of G.
    S_nonmatching = {s for s in S if not s.match(d)}
    S -= S_nonmatching
    for s in S_nonmatching:
        S.update(h for h in s.min_gen(d) if h.match(d) and some_more_general_or_eq(h, G))
    # Remove any member of S that is strictly more specific than another.
    S.difference_update([h for h in S if some_more_specific(h, S)])
    return (S, G)

def negative_step(d, S, G, attrs):
    """Update the S and G boundaries for a negative training instance d."""
    # Drop specific hypotheses that wrongly cover the negative instance.
    S = {s for s in S if not s.match(d)}
    # Minimally specialise any general hypotheses that cover it, keeping only
    # specialisations still bounded below by some member of S.
    G_matching = {g for g in G if g.match(d)}
    G -= G_matching
    for g in G_matching:
        G.update(h for h in g.min_spec(d, attrs) if not h.match(d) and some_more_specific_or_eq(h, S))
    # Remove any member of G that is strictly more general than another.
    G.difference_update([h for h in G if some_more_general(h, G)])
    return (S, G)

def some_more_general_or_eq(h, G):
    return any(g >= h for g in G)

def some_more_general(h, G):
    return any(g > h for g in G)

def some_more_specific_or_eq(h, S):
    return any(s <= h for s in S)

def some_more_specific(h, S):
    return any(s < h for s in S)

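# Hypothetical walk-through (not from the original gist) of one positive step
# on a two-attribute problem: the null hypothesis in S is generalised to
# exactly the observed instance, while the all-wildcard G is left untouched.
def example_positive_step():
    S = {Conjunction((None, None))}          # most specific boundary
    G = {Conjunction((Wildcard, Wildcard))}  # most general boundary
    d = Instance(('Sunny', 'Warm', '+'))     # made-up positive instance
    S, G = positive_step(d, S, G)
    assert S == {Conjunction(('Sunny', 'Warm'))}
    assert G == {Conjunction((Wildcard, Wildcard))}
    return S, G
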
def cea_step(inst, S, G, attributes):
    # The last field of a training instance is its '+'/'-' label.
    if inst[-1] == '+':
        return positive_step(inst, S, G)
    else:
        return negative_step(inst, S, G, attributes)

def cea_trace(dataset):
    """Run candidate elimination, returning the S and G boundaries after each instance."""
    attributes, instances = parse_dataset(dataset)
    Conjunction.attrs = attributes
    N = len(attributes)
    # Start from the most specific (null) and most general (all-wildcard) hypotheses.
    initial_S = { Conjunction(None for i in range(N)) }
    initial_G = { Conjunction(Wildcard for i in range(N)) }
    S, G = initial_S, initial_G
    S_trace, G_trace = [initial_S], [initial_G]
    for inst in instances:
        # Copy the boundaries so each entry in the trace stays independent.
        S = deepcopy(S)
        G = deepcopy(G)
        S, G = cea_step(inst, S, G, attributes)
        S_trace.append(S)
        G_trace.append(G)
    return S_trace, G_trace

def cea(dataset_str):
    """Return only the final S and G boundaries for the dataset."""
    S_trace, G_trace = cea_trace(dataset_str)
    return S_trace[-1], G_trace[-1]

def requires_voting(S, G, instance):
    """Return whether the given instance requires voting to be classified.

    If every hypothesis in S matches the instance, it can be confidently
    classified as positive.  If no hypothesis in G matches it, it can be
    confidently classified as negative.  Otherwise it requires voting.
    """
    inst = Instance(instance.split(', ') + [None])
    if all(s.match(inst) for s in S):
        return False
    if all(not g.match(inst) for g in G):
        return False
    return True

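# Hypothetical sketch (not from the original gist): with a singleton S and the
# all-wildcard G, only the one instance matched by S is classified without a
# vote; everything else falls between the boundaries.
def example_requires_voting():
    S = {Conjunction(('Sunny', 'Warm'))}
    G = {Conjunction((Wildcard, Wildcard))}
    assert not requires_voting(S, G, 'Sunny, Warm')  # all of S matches it
    assert requires_voting(S, G, 'Rainy, Cold')      # the boundaries disagree
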
def enumerate_vs(S, G):
    """Enumerate all hypotheses contained in the version space bounded by S and G."""
    hypotheses = set()
    hypotheses.update(S)
    for s in S:
        for g in G:
            hypotheses.update(set.union(*list(conjunctions_between(s, g))))
    return hypotheses

def conjunctions_between(s, g):
    """Yield all conjunctions between s and g, as sets (one per subset of
    disagreeing positions)."""
    # At each position where the boundaries disagree, a hypothesis in between
    # may take either the specific value from s or the more general one from g.
    places = [i for i in range(len(s)) if s[i] != g[i]]
    for ss in all_subsets(places):
        yield conjunctions_between_helper(s, g, ss)

def conjunctions_between_helper(s, g, ss):
    conj = Conjunction(g[i] if (i in ss) else s[i] for i in range(len(s)))
    if conj.is_null():
        # A null conjunction stands in for every fully specific hypothesis.
        return { Conjunction(hyp) for hyp in itertools.product(*Conjunction.attrs) }
    else:
        return { conj }

def all_subsets(lst):
    """Return all subsets of the given list, as tuples."""
    return itertools.chain(*all_subsets_nested(lst))

def all_subsets_nested(lst):
    return list(itertools.combinations(lst, i) for i in range(0, len(lst) + 1))

def vote(VS, instance_str):
    """Count how many hypotheses in the version space classify the instance
    as positive and as negative."""
    inst = Instance(instance_str.split(', ') + [None])
    pos_count = 0
    neg_count = 0
    for h in VS:
        if h.match(inst):
            pos_count += 1
        else:
            neg_count += 1
    return (pos_count, neg_count)

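# End-to-end usage sketch (not part of the original gist), using the
# hypothetical EXAMPLE_DATASET defined earlier in this file.
def example_end_to_end():
    # Final boundaries: S = {('Sunny', 'Warm')}, G = {('Sunny', ?), (?, 'Warm')}.
    S, G = cea(EXAMPLE_DATASET)
    # The version space between them contains exactly those three hypotheses.
    VS = enumerate_vs(S, G)
    assert len(VS) == 3
    # ('Sunny', 'Cold') lies between the boundaries, so it goes to a vote:
    # only ('Sunny', ?) matches it, the other two hypotheses do not.
    assert requires_voting(S, G, 'Sunny, Cold')
    pos, neg = vote(VS, 'Sunny, Cold')
    assert (pos, neg) == (1, 2)
    return S, G, VS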