Last active
August 23, 2023 21:37
-
-
Save sebinsua/67f58b8f7e9eb78da79b17d959e9dc73 to your computer and use it in GitHub Desktop.
An implementation of a sequential generator called `follow` that traverses linked lists and other data structures by repeatedly applying a step function until a stop condition is met.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from operator import eq, attrgetter as attr | |
from typing import TypeVar, Callable, Iterator, Optional | |
def eq_or_none(x, y):
    """Default ``follow`` stop predicate: stop when *y* is None or equals *x*.

    ``y is None`` is tested first, so a missing successor always stops the
    walk regardless of how *x*'s ``__eq__`` treats None.
    """
    return y is None or eq(x, y)


T = TypeVar('T')
def follow(
    start: T,
    next_fn: Callable[[T], T],
    stop: Callable[[T, T], bool] = eq_or_none,
) -> Iterator[Optional[T]]:
    """Yield *start*, then each value reached by repeatedly applying *next_fn*.

    After yielding a value ``current``, its successor is computed with
    ``next_fn(current)``. If ``stop(current, successor)`` is true, the
    traversal ends and that successor is NOT yielded. Otherwise the successor
    is yielded and becomes ``current``. With the default ``eq_or_none``, the
    walk ends when *next_fn* returns None or a value equal to the current one
    (a fixed point).

    If *next_fn* raises ``StopIteration``, the path cannot be continued. A
    final ``None`` is yielded to signal this, and the generator finishes.

    Args:
        start: The first value yielded.
        next_fn: Maps a value to its successor, or raises ``StopIteration``
            when there is no successor.
        stop: Predicate on ``(current, successor)`` that ends the traversal.

    Yields:
        *start*, every successor for which *stop* was false, and a trailing
        ``None`` if *next_fn* raised ``StopIteration``.
    """
    current = start
    while True:
        yield current
        # Every call to next_fn, including the very first one, must be
        # guarded. A StopIteration escaping a generator body is converted
        # into RuntimeError (PEP 479). An unguarded first call would
        # therefore crash instead of yielding the documented trailing None.
        try:
            successor = next_fn(current)
        except StopIteration:
            yield None
            return
        if stop(current, successor):
            return
        current = successor
def last(iterable: Iterator[Optional[T]]) -> Optional[T]:
    """Exhaust *iterable* and return the final item it produced.

    Returns None when *iterable* produces nothing.
    """
    final = None
    # The loop variable is rebound on every pass, so after exhaustion it
    # holds the most recent item (or keeps the None default if nothing came).
    for final in iterable:
        pass
    return final
class ListNode:
    """A singly linked list node with a ``value`` and a ``next`` reference.

    Equality compares only ``value``; the nodes that follow are ignored.
    """

    def __init__(self, value, next_node=None):
        self.value, self.next = value, next_node

    def __eq__(self, other):
        # Non-nodes get NotImplemented so Python can try the reflected
        # comparison (and fall back to identity).
        return (self.value == other.value) if isinstance(other, ListNode) else NotImplemented

    def __repr__(self):
        return f"ListNode(value={self.value})"
# Walk the linked list 1 -> 2 -> 3 by reading each node's ``next``
# attribute; with the default stop predicate the walk ends when ``next``
# is None, so the last value produced is the tail node.
head = ListNode(1, next_node=ListNode(2, next_node=ListNode(3)))
assert last(follow(head, attr("next"))) == ListNode(3)

# Follow 0 -> 4 -> 7 -> 8 -> 11 until it reaches a fixed point (11 maps to 11).
representatives = {0: 4, 1: 5, 4: 7, 7: 8, 8: 11, 11: 11}
assert last(follow(0, representatives.__getitem__)) == 11
class TreeNode:
    """A binary tree node with ``value``, ``left`` and ``right`` attributes.

    Equality compares only ``value``; the children are ignored.
    """

    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

    def __repr__(self):
        return f"TreeNode(value={self.value})"

    def __eq__(self, other):
        if isinstance(other, TreeNode):
            return self.value == other.value
        # Let Python try the reflected comparison for foreign types.
        return NotImplemented
# Example binary search tree
#         5
#        / \
#       3   8
#      / \   \
#     2   4   10
bst = TreeNode(
    5,
    left=TreeNode(3, left=TreeNode(2), right=TreeNode(4)),
    right=TreeNode(8, right=TreeNode(10)),
)
def search_bst(bst, value):
    """Return the node in binary search tree *bst* whose value equals *value*.

    Starting at the root, ``follow`` descends left when *value* is smaller
    than the current node's value and right when it is larger.

    Returns None when the tree is empty or when the search reaches a node
    with no child in the required direction.
    """
    if bst is None:
        # Empty tree: there is no root for bst_next to inspect.
        return None

    def bst_next(node):
        if value == node.value:
            # Returning the node itself tells the stop predicate "found".
            return node
        if node.left is not None and value < node.value:
            return node.left
        if node.right is not None and value > node.value:
            return node.right
        # No child to descend into: follow() turns this into a trailing None.
        raise StopIteration(
            "We have reached the end of the tree without finding the value."
        )

    def reached_target(node, successor):
        # Compare by identity rather than with the default ``eq_or_none``.
        # TreeNode.__eq__ compares values, so a child holding the same value
        # as its parent (a duplicate key) would otherwise end the walk early
        # and return the wrong node.
        return successor is None or successor is node

    return last(follow(bst, bst_next, reached_target))
# A value present in the tree: the returned node compares equal to TreeNode(8).
assert search_bst(bst=bst, value=8) == TreeNode(8)
# A value absent from the tree: the walk runs past leaf 10 and yields None.
assert search_bst(bst=bst, value=17) is None
class CycleDetected(Exception):
    """Raised by ``slow_and_fast`` to signal that a linked list contains a cycle."""
def slow_and_fast(head):
    """Run Floyd's tortoise-and-hare walk over the linked list at *head*.

    The tortoise advances one node per step and the hare advances two. The
    walk ends when the hare cannot take two more steps, and the last
    ``(tortoise, hare)`` pair reached is returned. For an acyclic list the
    tortoise is then the middle node (the lower middle for even lengths).
    An empty list (``head is None``) returns ``(None, None)``.

    Raises:
        CycleDetected: if the tortoise and hare land on the same node, which
            can only happen when the list loops back on itself.
    """
    if head is None:
        # Nothing to walk. Without this guard, ``tortoise.next`` below would
        # raise AttributeError on None.
        return None, None

    def next_fn(tortoise_and_hare):
        tortoise, hare = tortoise_and_hare
        if hare.next is None or hare.next.next is None:
            # The hare is at or next to the tail; a None hare ends the walk.
            return tortoise.next, None
        next_tortoise, next_hare = tortoise.next, hare.next.next
        # Identity, not ``==``: ListNode.__eq__ compares values, so an
        # acyclic list such as 1 -> 1 -> 1 would otherwise look cyclic.
        if next_tortoise is next_hare:
            raise CycleDetected("Cycle detected")
        return next_tortoise, next_hare

    def stop_fn(_, tortoise_and_hare):
        _, hare = tortoise_and_hare
        return hare is None

    tortoise_and_hare = head, head
    return last(follow(tortoise_and_hare, next_fn, stop_fn))
def has_cycle(head):
    """Return True when ``slow_and_fast`` reports a cycle for the list at *head*."""
    try:
        slow_and_fast(head)
    except CycleDetected:
        return True
    else:
        return False
# Find the middle of a linked list (of even and odd lengths).
# ListNode.__eq__ compares only values, so check identity against the
# expected node rather than building a look-alike list that is never compared.
even_length = ListNode(1, ListNode(2, ListNode(3, ListNode(4, ListNode(5, ListNode(6))))))
mid_even, _ = slow_and_fast(even_length)
assert mid_even is even_length.next.next  # node 3: lower middle of 6 nodes
odd_length = ListNode(1, ListNode(2, ListNode(3, ListNode(4, ListNode(5, ListNode(6, ListNode(7)))))))
mid_odd, _ = slow_and_fast(odd_length)
assert mid_odd is odd_length.next.next.next  # node 4: middle of 7 nodes
# Test cycle detection
head = ListNode(1, ListNode(2, ListNode(3)))
assert not has_cycle(head)
head.next.next.next = head.next  # 3 -> 2 closes a loop
assert has_cycle(head)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment