Created
June 18, 2010 16:53
-
-
Save christian-oudard/443891 to your computer and use it in GitHub Desktop.
Python flatten
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Flatten a nested list structure.

* Works for nested structures of lists, tuples, generators, or any other iterable.
* Special-cases string types and treats them as non-iterable.
* Is not limited to the system recursion limit.
* Yields items from the structure instead of constructing a new list, and can
  work on non-terminating generators.

This is basically a non-recursive version of the following:

    def flatten(iterable):
        iterable = iter(iterable)
        for item in iterable:
            if hasattr(item, '__iter__') and not isinstance(item, (str, bytes)):
                for i in flatten(item):
                    yield i
            else:
                yield item

Tested on Python 3.1, but easy to port to 2.6. Simply make it test for
basestring instead of (str, bytes).

>>> list(flatten([]))
[]
>>> list(flatten([1, []]))
[1]
>>> list(flatten([[0], 1]))
[0, 1]
>>> list(flatten([1, [2, [3, 4]]]))
[1, 2, 3, 4]
>>> list(flatten((1, (2, 3))))
[1, 2, 3]
>>> list(flatten(['one', ['two', ['three', 'four']]]))
['one', 'two', 'three', 'four']
>>> list(flatten([1, 2, [3, 4], (5,6), [7, [8, [9, [10]]]]]))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

>>> def make_nested_list(n):
...     result = []
...     for i in range(n):
...         result = [result, i]
...     return result
...
>>> import sys
>>> n = sys.getrecursionlimit() + 1
>>> assert list(range(n)) == list(flatten(make_nested_list(n)))

>>> def nested_gen(i=0):
...     yield i
...     yield nested_gen(i + 1)
...
>>> n = sys.getrecursionlimit() + 1
>>> from itertools import islice
>>> assert list(range(n)) == list(islice(flatten(nested_gen()), n))
""" | |
def flatten(iterable):
    """Lazily yield the leaf items of an arbitrarily nested iterable.

    Nesting is tracked on an explicit stack of suspended iterators instead of
    through recursive calls, so the depth of the structure is not bounded by
    the interpreter's recursion limit.  Items are produced one at a time, so
    a non-terminating generator can be consumed incrementally.

    Args:
        iterable: The outermost iterable to walk.  ``iter()`` is applied to
            it when the generator is first advanced, so a non-iterable
            argument raises ``TypeError`` at that point.

    Yields:
        Every leaf item in depth-first, left-to-right order.  ``str`` and
        ``bytes`` objects are yielded whole.  Objects that expose an
        ``__iter__`` attribute but are rejected by ``iter()`` (for example
        class objects such as ``dict`` itself) are also yielded whole rather
        than aborting the traversal.
    """
    current = iter(iterable)
    parents = []  # Suspended outer iterators, innermost last.
    while True:
        for item in current:
            if isinstance(item, (str, bytes)) or not hasattr(item, '__iter__'):
                yield item
                continue
            # hasattr() alone is not proof of iterability: a class such as
            # ``dict`` has an ``__iter__`` attribute, yet iter(dict) raises
            # TypeError.  Treat anything iter() rejects as a leaf.
            try:
                child = iter(item)
            except TypeError:
                child = None
            if child is None:
                yield item
                continue
            # Descend: remember where we were, then restart the scan on the
            # nested iterator.
            parents.append(current)
            current = child
            break
        else:
            # The current level is exhausted; resume the enclosing one, or
            # finish when there is none left.
            if not parents:
                return
            current = parents.pop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment