Skip to content

Instantly share code, notes, and snippets.

@crap0101
Forked from rik0/lazy_set.py
Created February 8, 2011 14:46
Show Gist options
  • Save crap0101/816529 to your computer and use it in GitHub Desktop.
#coding: utf-8
# Copyright (C) 2011 by Enrico Franchi (tweaks by crap0101)
#
# This file is released under the terms of the MIT license
# http://www.opensource.org/licenses/mit-license.php
import itertools as it
import functools
def silence_generator_already_executing(generator):
    """Yield the items of *generator*, stopping quietly on re-entrant use.

    ``lazy_set.update`` chains this wrapper into a set's pending input, so a
    set updated with itself ends up asking a generator that is already
    running for its next item; Python reports that with
    ``ValueError: generator already executing``.  Here that specific error
    simply means "no further new items", so iteration stops silently.
    Any other ValueError raised while iterating *generator* is a real error
    and is propagated instead of truncating the iteration unnoticed.
    """
    try:
        for element in generator:
            yield element
    except ValueError as exc:
        if 'already executing' not in str(exc):
            raise
class lazy_set(object):
    """
    Trace of implementation of lazy sets in Python.

    Elements are pulled from the input iterable(s) only when an operation
    needs them.  ``seen`` holds the distinct elements materialized so far and
    ``input_`` is an iterator over the elements not examined yet.  Derived
    sets (``copy``, ``union``, ``intersection``, ...) are lazy as well: they
    iterate their operands on demand.
    """

    def __init__(self, input_=()):
        self.input_ = iter(input_)
        self.seen = set()
        # Elements removed (or popped) from the set.  They may still be
        # pending in input_, so _iter skips them to keep them removed.
        self._removed = set()

    @staticmethod
    def _is_set(obj):
        """Return True if *obj* is a set operand accepted by the operators."""
        return isinstance(obj, (lazy_set, set, frozenset))

    def _check_operand(self, other, op):
        """Raise TypeError unless *other* is a set operand for operator *op*."""
        if not self._is_set(other):
            raise TypeError(
                "unsupported operand type(s) for %s: '%s' and '%s'" %
                (op, type(self).__name__, type(other).__name__))

    def _check_comparable(self, other):
        """Raise TypeError unless *other* can be order-compared with the set."""
        if not self._is_set(other):
            raise TypeError('can only compare to a set')

    def __contains__(self, item):
        try:
            # Fast path: an already materialized element is found in O(1).
            if item in self.seen:
                return True
            candidates = self._iter()
        except TypeError:
            # Unhashable item: fall back to an equality scan of everything.
            candidates = iter(self)
        for element in candidates:
            if item == element:
                return True
        return False

    def __len__(self):
        # Counting materializes every pending element.
        return sum(1 for _ in self)

    def _iter(self):
        """Yield pending elements not seen or removed yet, recording them."""
        for element in self.input_:
            if element in self.seen or element in self._removed:
                continue
            self.seen.add(element)
            yield element

    def _flush(self):
        """Materialize every pending element into ``seen``."""
        for _ in self:
            pass

    def _detach(self):
        """Move the current contents into a new lazy_set and empty this one.

        In-place updates compute their result lazily from the detached set.
        Computing it from ``self`` would re-read ``self.input_`` after it has
        been replaced by the result itself (a re-entrant read), which drops
        every element that was still pending.
        """
        detached = lazy_set()
        detached.input_ = self.input_
        detached.seen = self.seen
        detached._removed = self._removed
        self.clear()
        return detached

    def add(self, obj):
        """Add *obj* to the set (undoing an earlier removal of it)."""
        self._removed.discard(obj)
        self.seen.add(obj)

    def copy(self):
        """Return a new lazy_set that iterates this set on demand.

        The copy is a lazy view: elements of this set that are not
        materialized yet are read from it only when the copy is iterated.
        """
        return lazy_set(i for i in self)

    def clear(self):
        """Remove every element, pending or materialized."""
        self.input_ = iter(())
        self.seen = set()
        self._removed = set()

    def isdisjoint(self, other):
        """Return True if the set has no elements in common with other.
        Sets are disjoint if and only if their intersection is the empty set.
        """
        for item in other:
            if item in self:
                return False
        return True

    def __eq__(self, other):
        if not self._is_set(other):
            return False
        if len(self) != len(other):
            return False
        return all(item in other for item in self)

    def __ne__(self, other):
        return not self == other

    def issubset(self, other):
        """set <= other
        Test whether every element in the set is in other.
        """
        # Materialize self up front so that iterating it below cannot grow
        # ``seen`` while *other* might be a lazy view over this very set.
        self._flush()
        # Wrap *other* once: a one-shot iterator must not be restarted (and
        # thereby exhausted) for every single membership test.
        other = lazy_set(other)
        return all(element in other for element in self)

    def __le__(self, other):
        self._check_comparable(other)
        return self.issubset(other)

    def __lt__(self, other):
        self._check_comparable(other)
        return self.issubset(other) and self != other

    def issuperset(self, other):
        """set >= other
        Test whether every element in other is in the set.
        """
        other = lazy_set(other)
        return len(other) <= sum(1 for x in other if x in self)

    def __ge__(self, other):
        self._check_comparable(other)
        return self.issuperset(other)

    def __gt__(self, other):
        self._check_comparable(other)
        return self.issuperset(other) and self != other

    def union(self, *others):
        """set | other | ...
        Return a new set with elements from the set and all others.
        """
        new_lazy_set = self.copy()
        for other in others:
            new_lazy_set.update(other)
        return new_lazy_set

    def __or__(self, other):
        self._check_operand(other, '|')
        new_lazy_set = self.copy()
        new_lazy_set.update(other)
        return new_lazy_set

    __ror__ = __or__

    @staticmethod
    def _intersection(set1, set2):
        return lazy_set(el for el in set1 if el in set2)

    def intersection(self, *others):
        """set & other & ...
        Return a new set with elements common to the set and all others.
        """
        return functools.reduce(self._intersection, it.chain([self], others))

    def __and__(self, other):
        self._check_operand(other, '&')
        return self.intersection(other)

    __rand__ = __and__

    @staticmethod
    def _difference(set1, set2):
        return lazy_set(el for el in set1 if el not in set2)

    def difference(self, *others):
        """set - other - ...
        Return a new set with elements in the set that are not in the others.
        """
        return functools.reduce(self._difference, it.chain([self], others))

    def __sub__(self, other):
        self._check_operand(other, '-')
        return self.difference(other)

    def __rsub__(self, other):
        # Reflected operator: computes ``other - self`` (operand order matters).
        if not self._is_set(other):
            raise TypeError(
                "unsupported operand type(s) for -: '%s' and '%s'" %
                (type(other).__name__, type(self).__name__))
        return self._difference(other, self)

    def symmetric_difference(self, other):
        """set ^ other
        Return a new set with elements in either the set or other but not both.
        """
        diff1 = self._difference(self, other)
        diff2 = self._difference(other, self)
        return diff1 | diff2

    def __xor__(self, other):
        self._check_operand(other, '^')
        return self.symmetric_difference(other)

    __rxor__ = __xor__

    def update(self, iterator):
        """Lazily add the elements of *iterator* (any iterable) to the set."""
        pending = self.input_
        if self._removed:
            # Removals only apply to elements that were already pending; the
            # new iterable may legitimately add them back.
            removed, self._removed = self._removed, set()
            pending = (element for element in pending
                       if element not in removed)
        self.input_ = it.chain(pending,
                               silence_generator_already_executing(iterator))

    def __ior__(self, other):
        self._check_operand(other, '|=')
        self.update(other)
        # In-place operators return the updated object itself, so every
        # reference to this set observes the change.
        return self

    def intersection_update(self, *others):
        """set &= other & ...
        Update the set, keeping only elements found in it and all others.
        """
        old = self._detach()
        others = [old if other is self else other for other in others]
        self.update(old.intersection(*others))

    def __iand__(self, other):
        self._check_operand(other, '&=')
        self.intersection_update(other)
        return self

    def difference_update(self, *others):
        """set -= other | ...
        Update the set, removing elements found in others.
        """
        old = self._detach()
        others = [old if other is self else other for other in others]
        self.update(old.difference(*others))

    def __isub__(self, other):
        self._check_operand(other, '-=')
        self.difference_update(other)
        return self

    def symmetric_difference_update(self, other):
        """set ^= other
        Update the set, keeping only elements found in either set,
        but not in both.
        """
        old = self._detach()
        if other is self:
            other = old
        self.update(old.symmetric_difference(other))

    def __ixor__(self, other):
        self._check_operand(other, '^=')
        self.symmetric_difference_update(other)
        return self

    def remove(self, item):
        """Remove element *item* from the set.
        Raises KeyError if *item* is not contained in the set.
        """
        if item not in self:
            raise KeyError(item)
        self.seen.remove(item)
        # A duplicate of item may still be pending; keep it from reappearing.
        self._removed.add(item)

    def discard(self, item):
        """Remove element *item* from the set if it is present."""
        try:
            self.remove(item)
        except KeyError:
            pass

    def pop(self):
        """Remove and return an arbitrary element from the set.
        Raises KeyError if the set is empty.
        """
        if not self.seen:
            # Nothing materialized yet: pull one pending element, if any.
            next(self._iter(), None)
        element = self.seen.pop()
        self._removed.add(element)
        return element

    def __iter__(self):
        return it.chain(self.seen, self._iter())
if __name__ == "__main__":
    # Run as a script: execute the doctests embedded in this module.
    from doctest import testmod
    testmod()
import itertools as it
import string
import unittest
import lazy_set
class LazySetTest(unittest.TestCase):
    """Tests for lazy_set.lazy_set, cross-checked against set/frozenset."""

    def setUp(self):
        self.set_types = (lazy_set.lazy_set, set, frozenset)
        self.lazy_set = lazy_set.lazy_set(xrange(1, 4))
        self.starting_value = range(1, 4)

    def testCreation(self):
        self.assertEqual(sorted(self.lazy_set), self.starting_value)

    def testUpdate(self):
        self.lazy_set.update(2 * x for x in range(4, 10))
        self.assertEqual(
            sorted(self.lazy_set),
            [1, 2, 3, 8, 10, 12, 14, 16, 18]
        )

    def testAdd(self):
        self.lazy_set.add(5)
        self.assertEqual(sorted(self.lazy_set), [1, 2, 3, 5])

    def testUniqueness(self):
        # overlapping updates must not introduce duplicates
        self.lazy_set.update(xrange(5))
        self.lazy_set.update(xrange(3, 8))
        self.assertEqual(sorted(self.lazy_set), range(8))

    def testSelfUpdateWIter(self):
        # updating a set with an iterator over itself leaves it unchanged
        self.lazy_set.update(iter(self.lazy_set))
        self.assertEqual(sorted(self.lazy_set), self.starting_value)

    def testSelfUpdateWOIter(self):
        # same as above, passing the set object itself
        self.lazy_set.update(self.lazy_set)
        self.assertEqual(sorted(self.lazy_set), self.starting_value)

    def testSelfUpdateNotLats(self):
        # a self-update must not swallow an update that follows it
        self.lazy_set.update(iter(self.lazy_set))
        self.lazy_set.update(xrange(5, 8))
        self.assertEqual(
            sorted(self.lazy_set),
            range(1, 4) + range(5, 8)
        )

    def testLen(self):
        self.assertEqual(len(self.lazy_set), len(self.starting_value))
        self.lazy_set.update(self.starting_value)
        self.assertEqual(len(self.lazy_set), len(self.starting_value))
        # inputs whose items are all distinct
        strings = ['awert', 'i', '12', 'ferhjkl', 'cvbnm', 'bar', 'baz']
        lists = [range(x) for x in range(20)]
        for item in it.chain(strings, lists):
            lset = lazy_set.lazy_set(item)
            self.assertEqual(len(lset), len(item))
        numb1 = range(10)
        numb2 = range(11, 20)
        lset = lazy_set.lazy_set(numb1)
        init_len = len(lset)
        for add, n in zip(it.count(1), numb2):
            lset.add(n)
            self.assertEqual(len(lset), init_len + add)

    def testCopy(self):
        new_lazy_set = self.lazy_set.copy()
        self.assertEqual(self.lazy_set, new_lazy_set)
        self.assertEqual(list(sorted(self.lazy_set)),
                         list(sorted(new_lazy_set)))
        other_lazy_set = new_lazy_set.copy()
        self.assertEqual(self.lazy_set, other_lazy_set)
        self.lazy_set.add(999)
        self.assertNotEqual(self.lazy_set, other_lazy_set)
        self.assertNotEqual(self.lazy_set, new_lazy_set)

    def testIsdisjoint(self):
        # consecutive 5-letter slices of the alphabet are pairwise disjoint
        set_strings = list(lazy_set.lazy_set(string.letters[n:n + 5])
                           for n in range(0, len(string.letters), 5))
        for s1, s2 in it.combinations(set_strings, 2):
            self.assertTrue(s1.isdisjoint(s2))
            self.assertTrue(s2.isdisjoint(s1))
            s1.update(s2)
            self.assertFalse(s1.isdisjoint(s2))

    def testEquality(self):
        new_lazy_set = lazy_set.lazy_set(i for i in self.lazy_set)
        self.assertEqual(self.lazy_set, new_lazy_set)
        self.assertFalse(self.lazy_set != new_lazy_set)
        new_lazy_set.update(list(i for i in self.lazy_set))
        self.assertEqual(self.lazy_set, new_lazy_set)
        self.assertFalse(self.lazy_set != new_lazy_set)
        self.lazy_set.update(range(100))
        new_lazy_set.update(range(100, 200))
        self.assertNotEqual(self.lazy_set, new_lazy_set)
        foo_set = '1234567'
        new_lazy_set = lazy_set.lazy_set(foo_set)
        for settype in self.set_types:
            eq_set = settype(foo_set)
            self.assertEqual(new_lazy_set, eq_set)
        # non-set containers never compare equal to a lazy_set
        for othertype in (list, tuple, str):
            ne_set = othertype(foo_set)
            self.assertNotEqual(new_lazy_set, ne_set)

    def testIsSubSuperSet(self):
        def op_cmp_sub(s1, s2):
            return s1 <= s2

        def op_cmp_sup(s1, s2):
            return s1 >= s2

        other_set = self.lazy_set.copy()
        self.assertTrue(self.lazy_set.issubset(other_set))
        for i in range(20):
            other_set.add(i)
            self.assertTrue(self.lazy_set.issubset(other_set))
            self.assertTrue(self.lazy_set <= other_set)
            self.assertTrue(other_set.issuperset(self.lazy_set))
            self.assertTrue(other_set >= self.lazy_set)
        items = list(self.lazy_set)
        other_set = self.lazy_set.copy()
        for item in items:
            other_set.remove(item)
            self.assertFalse(self.lazy_set.issubset(other_set))
            self.assertFalse(self.lazy_set <= other_set)
            self.assertFalse(other_set.issuperset(self.lazy_set))
            self.assertFalse(other_set >= self.lazy_set)
        foo_set = '1234567'
        new_lazy_set = lazy_set.lazy_set(foo_set)
        for settype in self.set_types:
            sub_set = settype(foo_set)
            self.assertTrue(new_lazy_set.issubset(sub_set))
            self.assertTrue(sub_set.issuperset(new_lazy_set))
        # the named methods accept any iterable, the operators only sets
        for othertype in (list, tuple, str):
            no_set = othertype(foo_set)
            self.assertTrue(new_lazy_set.issubset(no_set))
            self.assertRaises(TypeError, op_cmp_sub, new_lazy_set, no_set)
            self.assertRaises(TypeError, op_cmp_sup, no_set, new_lazy_set)

    def testUnion(self):
        def op_cmp_or(s, other):
            return s | other

        items = [range(n, n + 10) for n in range(10)]
        sets = [lazy_set.lazy_set(item) for item in items]
        new_lazy_set = lazy_set.lazy_set()
        total_union = new_lazy_set.union(*sets)
        for settype in self.set_types:
            set2 = settype()
            set3 = settype()
            for s in sets:
                set2 = set2.union(s)
                set3 |= s
            self.assertEqual(total_union, set2)
            self.assertEqual(total_union, set3)
        no_set = '1234567'
        # the | operator rejects every non-set operand type
        for othertype in (list, tuple, str):
            self.assertRaises(TypeError, op_cmp_or,
                              total_union, othertype(no_set))
        items = [('12345', '56781'), ((1, 2, 3), range(11, 30))]
        total_set = lazy_set.lazy_set()
        for fst, snd in items:
            new_lazy_set = lazy_set.lazy_set(fst)
            setx = new_lazy_set | lazy_set.lazy_set(snd)
            new_lazy_set |= lazy_set.lazy_set(snd)
            self.assertEqual(new_lazy_set, setx)
            total_set |= lazy_set.lazy_set(fst) | lazy_set.lazy_set(snd)
        t = lazy_set.lazy_set()
        for i in it.chain(*items):
            t = t.union(i)
        self.assertEqual(total_set, t)

    def testIntersection(self):
        items = [('123456', '136789'),
                 (range(5, 22), range(17, 33)),
                 ('abcdefvghjyrfewfdger', 'rgerhhsafwfrhhtjh'),
                 ((1, 2, 3), (7, 8, 9)),
                 (('a', 2, ()), ())]
        for pair in items:
            set1, set2 = map(set, pair)
            lset1, lset2 = map(lazy_set.lazy_set, pair)
            self.assertEqual(lset1.intersection(lset2),
                             set1.intersection(set2))
            self.assertEqual(lset2.intersection(lset1),
                             set2.intersection(set1))
            # fresh operands for the operator forms
            set1, set2 = map(set, pair)
            lset1, lset2 = map(lazy_set.lazy_set, pair)
            self.assertEqual(set1 & lset2, lset1 & lset2)
            self.assertEqual(lazy_set.lazy_set(set1 & set2),
                             lset1 & lset2)
            lset3 = lset1 & lset2
            lset1 &= lset2
            self.assertEqual(lset1, lset3)

    def testDifference(self):
        items = [('123456', '136789'),
                 (range(5, 22), range(17, 33)),
                 ('abcdefvghjyrfewfdger', 'rgerhhsafwfrhhtjh'),
                 ((1, 2, 3), (7, 8, 9)),
                 (('a', 2, ()), ())]
        for pair in items:
            set1, set2 = map(set, pair)
            lset1, lset2 = map(lazy_set.lazy_set, pair)
            self.assertEqual(lset1.difference(lset2),
                             set1.difference(set2))
            self.assertEqual(lset2.difference(lset1),
                             set2.difference(set1))
            set1, set2 = map(set, pair)
            lset1, lset2 = map(lazy_set.lazy_set, pair)
            self.assertEqual(lazy_set.lazy_set(set1 - set2),
                             lset1 - lset2)  # FIX? (rcmp)
            self.assertEqual(lazy_set.lazy_set(set1 - set2), lset1 - lset2)
            # __isub__
            lset3 = lset1 - lset2
            lset1 -= lset2
            self.assertEqual(lset1, lset3)
            # symmetric_difference
            set1, set2 = map(set, pair)
            lset1, lset2 = map(lazy_set.lazy_set, pair)
            self.assertEqual(lset1.symmetric_difference(lset2),
                             set1.symmetric_difference(set2))
            self.assertEqual(lset2.symmetric_difference(lset1),
                             set2.symmetric_difference(set1))
            set1, set2 = map(set, pair)
            lset1, lset2 = map(lazy_set.lazy_set, pair)
            self.assertEqual(lazy_set.lazy_set(set1 ^ set2),
                             lset1 ^ lset2)  # FIX? (rcmp)
            # __ixor__
            set3 = set(set1)
            set3 ^= set2
            lset3 = lset1.copy()
            lset3 ^= lset2
            self.assertEqual(set3, set1 ^ set2)
            # NOTE(review): this compares set3 with itself, so lset3 is
            # never checked against the builtin result here; presumably the
            # intent was self.assertEqual(lset3, set3).
            self.assertEqual(set3, set3)
            lset1 ^= lset2
            self.assertEqual(lset1, lset3)

    def testPopRemoveDiscardClear(self):
        items = [range(10), 'qwerty']
        # remove() every element, then check remove/discard on missing ones
        for item in items:
            new_set = lazy_set.lazy_set(item)
            orig_set = new_set.copy()
            for element in item:
                new_set.remove(element)
                self.assertTrue(new_set.issubset(orig_set))
            for element in item:
                self.assertRaises(KeyError, new_set.remove, element)
                new_set.discard(element)
        # clear()
        orig_set = lazy_set.lazy_set('fwokdkdfpdkpfkepfkwepf')
        old_set = lazy_set.lazy_set('fwokdkdfpdkpfkepfkwepf')
        orig_set.clear()
        self.assertNotEqual(orig_set, old_set)
        self.assertEqual(orig_set, lazy_set.lazy_set())
        # pop() until empty, then it must raise KeyError
        for item in items:
            lset = lazy_set.lazy_set(item)
            cpset = lazy_set.lazy_set(lset)
            for el in cpset:
                ret = lset.pop()
                self.assertTrue(ret in cpset)
            self.assertRaises(KeyError, lset.pop)
if __name__ == '__main__':
    # Run as a script: execute the unittest test cases of this module.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment