Created July 30, 2019 22:19
-
-
Save Palisand/9d4304ca23fe3655c4a671749b96920c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar
T = TypeVar("T")  # type of the items in the input collection
U = TypeVar("U")  # type of the grouping keys produced by ``func``


def group_by(
    func: Callable[[T], U],
    coll: Iterable[T],
    xform: Optional[Callable[[T], Any]] = None,
) -> Dict[U, List[Any]]:
    """
    Group the items of ``coll`` by the key that ``func`` computes for each one.

    Returns a dictionary mapping each key generated by ``func`` to the list of
    items from ``coll`` that produced it. Items keep their original iteration
    order within each list. If ``xform`` is supplied, each item is passed
    through it before being added to the dictionary; the key is always
    computed from the original, untransformed item.

    Args:
        func: Computes the grouping key for an item. Keys must be hashable,
            since they are used as dictionary keys.
        coll: The items to group. It is iterated exactly once, so a
            generator is acceptable.
        xform: Optional transformation applied to each item before it is
            stored. When ``None``, items are stored as-is.

    Returns:
        The grouping, as a ``collections.defaultdict(list)``. NOTE: looking up
        a key that is not present on the result inserts and returns an empty
        list instead of raising ``KeyError``.

    >>> group_by(len, ["a", "bb", "c"]) == {1: ["a", "c"], 2: ["bb"]}
    True
    >>> group_by(len, ["a", "bb"], str.upper) == {1: ["A"], 2: ["BB"]}
    True
    """
    groups: Dict[U, List[Any]] = defaultdict(list)
    for item in coll:
        # The key is computed before the (optional) transform runs, so func
        # always sees the raw item.
        bucket = groups[func(item)]
        bucket.append(item if xform is None else xform(item))
    return groups
class TestGroupBy:
    """Checks ``group_by`` on a small list of ints split into "> 5" / "< 5" buckets."""

    def test_without_xform(self):
        numbers = [4, 7, 1, 2, 9, 6]
        expected = {"> 5": [7, 9, 6], "< 5": [4, 1, 2]}
        actual = group_by(lambda n: "> 5" if n > 5 else "< 5", numbers)
        assert expected == actual

    def test_with_xform(self):
        # Same grouping as the test above, but every stored item goes through str().
        numbers = [4, 7, 1, 2, 9, 6]
        expected = {"> 5": ["7", "9", "6"], "< 5": ["4", "1", "2"]}
        actual = group_by(lambda n: "> 5" if n > 5 else "< 5", numbers, str)
        assert expected == actual
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.