Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Created December 6, 2018 15:23
Show Gist options
  • Save mbillingr/0f9740a46a735649d94f345a440af2c7 to your computer and use it in GitHub Desktop.
Save mbillingr/0f9740a46a735649d94f345a440af2c7 to your computer and use it in GitHub Desktop.
Python class that wraps any iterable to enable iterator composition by chaining method calls (similar to Rust's iterators).
import itertools
import functools
class FriendlyIter:
    """Wrap any iterable so iterator operations can be chained as method calls.

    Adapter methods (``map``, ``filter``, ``take``, ...) return a new
    ``FriendlyIter`` that lazily pulls from the *same* underlying iterator,
    so advancing an adapter also advances the wrapper it was created from.
    Consumer methods (``count``, ``last``, ``reduce``, ...) exhaust the
    iterator (or the part of it they need) and return a plain value.
    """

    # Private sentinel that distinguishes "argument not given" from an
    # explicit ``None`` (a legitimate element or initial value).
    _MISSING = object()

    def __init__(self, it):
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.it)

    def all(self, pred=None):
        """Return True if every element (or ``pred(element)``) is truthy.

        Stops consuming at the first falsy result. Without ``pred`` the
        elements themselves are tested, like the builtin ``all``.
        """
        if pred is None:
            return all(self.it)
        return all(pred(x) for x in self.it)

    def any(self, pred=None):
        """Return True if some element (or ``pred(element)``) is truthy.

        Stops consuming at the first truthy result. Without ``pred`` the
        elements themselves are tested, like the builtin ``any``.
        """
        if pred is None:
            return any(self.it)
        return any(pred(x) for x in self.it)

    def for_each(self, func):
        """Consume the iterator and call ``func`` with each element."""
        for x in self.it:
            func(x)

    def consume(self):
        """Consume the iterator, discarding every element; return None.

        Safe on an empty iterator.
        """
        for _ in self.it:
            pass

    def count(self):
        """Consume the iterator and return the number of elements."""
        return sum(1 for _ in self.it)

    def last(self, default=_MISSING):
        """Consume the iterator and return its last element.

        If the iterator is empty, return ``default`` when given; otherwise
        raise ``ValueError``.
        """
        result = default
        for result in self.it:
            pass
        if result is FriendlyIter._MISSING:
            raise ValueError("last() called on an empty iterator")
        return result

    def nth(self, n):
        """Return the n-th remaining element, consuming it and all before it.

        ``n`` is 1-based: ``nth(1)`` returns the next element. Any
        ``n <= 1`` behaves like ``nth(1)``. Raises ``StopIteration`` if
        fewer than ``n`` elements remain.
        """
        return self.skip(n - 1).next()

    def next(self):
        """Advance the iterator and return the next value.

        Raises ``StopIteration`` when the iterator is exhausted.
        """
        return next(self)

    def reduce(self, func, initializer=_MISSING):
        """Fold the elements with ``func`` via ``functools.reduce``.

        ``initializer`` may be any value, including ``None``. When omitted,
        the first element seeds the fold, and an empty iterator makes
        ``functools.reduce`` raise ``TypeError``.
        """
        if initializer is FriendlyIter._MISSING:
            return functools.reduce(func, self.it)
        return functools.reduce(func, self.it, initializer)

    def chain(self, *iterables):
        """Yield this iterator's elements, then those of each of ``iterables``."""
        return FriendlyIter(itertools.chain(self.it, *iterables))

    def inspect(self, func=print):
        """Call ``func`` on each element as it passes through, then yield it.
        (x) -> (x)
        """
        def _inspect(it=self.it):
            for x in it:
                func(x)
                yield x
        return FriendlyIter(_inspect())

    def enumerate(self, start=None):
        """Yield ``(index, element)`` pairs; indices begin at ``start`` (default 0).
        (x) -> (i, x)
        """
        return FriendlyIter(enumerate(self.it, 0 if start is None else start))

    def filter(self, pred):
        """Remove elements from iteration for which the predicate does not return true.
        (x) -> (x) if pred(x) else ()
        """
        return FriendlyIter(filter(pred, self.it))

    def flatten(self):
        """Yield the items of each element in turn (one level of nesting).

        Every element must itself be iterable; a non-iterable element raises
        ``TypeError`` when it is reached.
        (x) -> (*x)
        """
        return FriendlyIter(itertools.chain.from_iterable(self.it))

    def map(self, func):
        """Transform each element with function.
        (x) -> (func(x))
        """
        return FriendlyIter(map(func, self.it))

    def scan(self, initial_state, func):
        """Process elements while threading a state variable through them.

        ``func(state, x)`` must return a ``(new_state, output)`` pair;
        ``output`` is yielded and ``new_state`` is passed to the next call.
        """
        def _scan(it=self.it, state=initial_state):
            for x in it:
                state, y = func(state, x)
                yield y
        return FriendlyIter(_scan())

    def skip(self, n):
        """Skip the first ``n`` elements (lazily, on first iteration)."""
        def _skip(it=self.it):
            # range() goes first so zip stops without pulling an extra
            # element from ``it``.
            for _ in zip(range(n), it):
                pass
            yield from it
        return FriendlyIter(_skip())

    def step(self, n):
        """Yield every n-th element, starting with the first (indices 0, n, 2n, ...).

        ``n`` must be a positive integer; otherwise ``ValueError`` is raised
        immediately.
        """
        return FriendlyIter(itertools.islice(self.it, 0, None, n))

    def take(self, n):
        """Stop iteration after ``n`` elements."""
        def _take(it=self.it):
            # range() goes first so no element beyond the n-th is consumed
            # from the shared underlying iterator.
            for _, x in zip(range(n), it):
                yield x
        return FriendlyIter(_take())

    def zip(self, *iterables):
        """Combine iterators in lockstep
        (x), (x1), (x2), ... -> (x, x1, x2, ...)
        """
        return FriendlyIter(zip(self.it, *iterables))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment