Last active
April 16, 2020 13:41
-
-
Save chrislawlor/0612e8cdaa564e7ed25fd8fc7382199d to your computer and use it in GitHub Desktop.
Python iterable processing with delayed execution. Inspired by LINQ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Copyright 2020 Christopher Lawlor
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from typing import Any, Callable, Generator, Hashable, Iterable, List, TypeVar | |
# Type variable standing for a single element flowing through a query.
T = TypeVar("T")
# A query expression: a callable that takes one element and returns any value.
TExpr = Callable[[T], Any]
class _Element:
    """
    Stand-in for "the current item" when writing query expressions.

    Comparing the placeholder with a value does not yield a bool. It yields
    a one-argument predicate that applies the same comparison to whichever
    item it is later called with.
    """

    def __eq__(self, other):
        def is_equal(item):
            return item == other

        return is_equal

    def __ne__(self, other):
        def is_not_equal(item):
            return item != other

        return is_not_equal

    def __gt__(self, other):
        def is_greater(item):
            return item > other

        return is_greater

    def __ge__(self, other):
        def is_greater_or_equal(item):
            return item >= other

        return is_greater_or_equal

    def __lt__(self, other):
        def is_less(item):
            return item < other

        return is_less

    def __le__(self, other):
        def is_less_or_equal(item):
            return item <= other

        return is_less_or_equal

    def __call__(self, *args, **kwargs):
        # The bare placeholder carries no comparison, so calling it as if it
        # were a predicate is always an error.
        raise ValueError("Must be used in an expression")

    def __repr__(self):
        return "<class 'linq.Q'>"


# Client code uses the public ``Element`` object defined below. Implementing
# the comparison overrides as class or static methods, so the class itself
# could serve as the placeholder, would have been preferable, but that does
# not seem to work. The placeholder holds no state and client code has no
# reason to instantiate it, so a single shared instance is a reasonable
# compromise.
_Element.__name__ = "Element"
Element = _Element()
class Linq:
    """
    Lazily evaluated, chainable query over an iterable.

    The chainable methods (``distinct``, ``where``, ``apply``) only record an
    operation and return this same ``Linq`` instance. Nothing is evaluated
    until the query is iterated or a terminating method (``any``, ``count``)
    is called.
    """

    def __init__(self, collection: Iterable[T]):
        """
        Create a query reading from ``collection``.
        """
        self._collection = collection

        # Seed the operations stack with an identity operator. This makes
        # the query iterable immediately and means __iter__ never has to
        # deal with an empty stack.
        def identity(stream):
            yield from stream

        self._operations: List[TExpr] = [identity]

    # Chainable methods

    def distinct(self, hash_=False) -> "Linq":
        """
        Remove duplicates, keeping the first occurrence of each element.

        This requires keeping a set of every seen element in memory, so
        elements must be hashable.

        Since only membership matters, and the seen values are never read
        back, ``distinct`` supports a ``hash_`` option. If set, only the hash
        of each element is stored. For cases where the hash is smaller than
        the element, this can significantly reduce memory requirements. The
        trade-off is that two different elements with the same hash are
        treated as duplicates.
        """

        def _distinct(collection):
            seen = set()
            for item in collection:
                # Only call hash() when the caller asked for it.
                key = hash(item) if hash_ else item
                if key not in seen:
                    seen.add(key)
                    yield item

        self._operations.append(_distinct)
        return self

    def where(self, filter_: TExpr) -> "Linq":
        """
        Allow only elements for which ``filter_`` returns a truthy value.
        """

        def where_(collection):
            yield from filter(filter_, collection)

        self._operations.append(where_)
        return self

    def apply(self, selector: TExpr) -> "Linq":
        """
        Apply some transform to each received element.
        """

        def select_(collection):
            return (selector(item) for item in collection)

        self._operations.append(select_)
        return self

    # Terminating methods

    def any(self, filter_: TExpr) -> bool:
        """
        Report whether any element produced by the query satisfies ``filter_``.

        The predicate runs against the output of the chained operations, not
        the raw source collection. Its result decides the answer, rather than
        the truthiness of the matching element, so falsy matches such as
        ``0`` or ``""`` count. Evaluation stops at the first match.
        """
        return any(filter_(item) for item in self)

    def count(self) -> int:
        """
        Get the final count of the query's results.

        This runs the whole query, which consumes the source if it is a
        one-shot iterator.
        """
        return len(self)

    # Operators

    @staticmethod
    def eq(value) -> Callable[[T], bool]:
        """Build a predicate testing ``x == value``."""
        return lambda x: x == value

    @staticmethod
    def ne(value) -> Callable[[T], bool]:
        """Build a predicate testing ``x != value``."""
        return lambda x: x != value

    @staticmethod
    def gt(value) -> Callable[[T], bool]:
        """Build a predicate testing ``x > value``."""
        return lambda x: x > value

    @staticmethod
    def gte(value) -> Callable[[T], bool]:
        """Build a predicate testing ``x >= value``."""
        return lambda x: x >= value

    @staticmethod
    def lt(value) -> Callable[[T], bool]:
        """Build a predicate testing ``x < value``."""
        return lambda x: x < value

    @staticmethod
    def lte(value) -> Callable[[T], bool]:
        """Build a predicate testing ``x <= value``."""
        return lambda x: x <= value

    # Internals

    def __len__(self):
        # Gets the query result count without storing results in memory.
        #
        # NOTE(review): because __len__ exists, callers that ask for a size
        # hint before iterating (CPython's list() appears to) run the whole
        # pipeline once more, which exhausts a one-shot source before the
        # real iteration starts. Kept so len(query) keeps working; consider
        # deprecating it in favour of count().
        return sum(1 for _ in self)

    def __iter__(self):
        # Compose the recorded operations
        #   [f0, f1, f2, ... fn]
        # into
        #   fn(...f2(f1(f0(collection))))
        # The stack always holds at least the identity operator, so the
        # result is always an iterable over the source.
        stream = self._collection
        for operation in self._operations:
            stream = operation(stream)
        yield from stream
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from linq import Linq, Element | |
def test_any():
    """any() is True when an element satisfies the predicate."""
    matched = Linq(["foo"]).any(lambda s: s == "foo")
    assert matched is True
def test_eq():
    """Linq.eq builds a predicate that matches an equal element."""
    is_foo = Linq.eq("foo")
    assert Linq(["foo"]).any(is_foo) is True
def test_ne():
    """Linq.ne matches nothing when the only element equals the value."""
    not_foo = Linq.ne("foo")
    assert Linq(["foo"]).any(not_foo) is False
def test_gt():
    """Linq.gt keeps only elements strictly above the bound."""
    kept = list(Linq([5, 10]).where(Linq.gt(5)))
    assert kept == [10]
def test_gte():
    """Linq.gte keeps elements at or above the bound."""
    kept = list(Linq([5, 6, 10]).where(Linq.gte(6)))
    assert kept == [6, 10]
def test_lt():
    """Linq.lt keeps only elements strictly below the bound."""
    kept = list(Linq([5, 10]).where(Linq.lt(10)))
    assert kept == [5]
def test_lte():
    """Linq.lte keeps elements at or below the bound."""
    kept = list(Linq([5, 9, 10]).where(Linq.lte(9)))
    assert kept == [5, 9]
def test_count():
    """count() reports how many elements survive the filter."""
    above_five = Linq(range(10)).where(Linq.gt(5))
    # 6, 7, 8 and 9 pass the filter
    assert above_five.count() == 4
def test_always_iterable():
    """A query with no chained operations yields its source unchanged."""
    assert list(Linq(["foo"])) == ["foo"]
def test_distinct():
    """distinct() collapses repeated elements."""
    unique = list(Linq(["foo", "foo"]).distinct())
    assert unique == ["foo"]
def test_distinct_with_hash():
    """distinct(hash_=True) collapses repeated elements by hash."""
    unique = list(Linq(["foo", "foo"]).distinct(hash_=True))
    assert unique == ["foo"]
def test_chain():
    """A terminating method can follow a chainable one."""
    query = Linq(["foo", "foo", "bar"]).distinct()
    assert query.any(lambda n: n == "bar") is True
def test_where():
    """where() keeps only elements satisfying the predicate."""
    matches = list(Linq(["alice", "bob"]).where(lambda n: n == "alice"))
    assert matches == ["alice"]
def test_apply():
    """apply() transforms elements before later operations see them."""
    name_lengths = Linq(["alice", "bob"]).apply(lambda x: len(x))
    long_lengths = name_lengths.where(Linq.gt(3))
    assert list(long_lengths) == [5]
def test_Q_eq():
    """``Element == value`` builds an equality predicate."""
    is_foo = Element == "foo"
    assert is_foo("foo") is True
    assert is_foo("bar") is False
def test_Q_ne():
    """``Element != value`` builds an inequality predicate."""
    not_foo = Element != "foo"
    assert not_foo("foo") is False
    assert not_foo("bar") is True
def test_Q_gt():
    """``Element > value`` builds a strict greater-than predicate."""
    above_five = Element > 5
    assert above_five(6) is True
    assert above_five(5) is False
    assert above_five(4) is False
def test_Q_gte():
    """``Element >= value`` builds a greater-or-equal predicate."""
    at_least_five = Element >= 5
    assert at_least_five(5) is True
    assert at_least_five(6) is True
    assert at_least_five(4) is False
def test_any_with_Q():
    """An Element expression can be passed to any() as the predicate."""
    query = Linq(["foo", "bar"])
    assert query.any(Element == "foo") is True
def test_value_error_with_naked_Element():
    """Passing the bare placeholder as a predicate raises ValueError."""
    query = Linq(["a"])
    with pytest.raises(ValueError):
        query.any(Element)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment