from typing import Iterable, Callable, TypeVar T = TypeVar("T") R = TypeVar("R") def make_none_safe(func: Callable[[Iterable[T]], Iterable[R]], *, batch_size: int | None = None) -> Callable[[Iterable[T | None]], Iterable[R]]: """Turn `iterable -> iterable` function into one that is safe for None values. Consider if you have a function of the form `Iterable[T] -> Iterable[R]`, and this function is delicate and will raise an error if it encounters a None value in one of the inputs. This function will make a new function that is safe for None values, and will return None for any input that is None. Parameters ---------- func: The function that is not safe for None values. batch_size: The number of items to process at a time. If None, then the function will be applied to the entire input at once. Returns ------- A new function that is safe for None values. Examples -------- >>> def bulk_add_one(ins): ... for x in ins: ... yield x + 1 >>> safe = make_none_safe(bulk_add_one, batch_size=100) >>> list(safe([1, 2, None, 3])) [2, 3, None, 4] """ if batch_size is not None and batch_size < 1: raise ValueError("batch_size must be None or a positive integer.") if batch_size is None: batch_size = float("inf") def safe_func(ins: Iterable[T | None]) -> Iterable[R]: ins_iter = iter(ins) while True: none_idxs = set() non_nones = [] n_read = 0 while True: if n_read >= batch_size: break try: v = next(ins_iter) except StopIteration: break if v is None: none_idxs.add(n_read) else: non_nones.append(v) n_read += 1 if n_read == 0: break results = iter(func(non_nones)) for j in range(n_read): if j in none_idxs: yield None else: yield next(results) return safe_func # tests def bulk_add_one(ins: list[int]) -> list[int]: return [x + 1 for x in ins] safe_unbatched = make_none_safe(bulk_add_one) safe_batched100 = make_none_safe(bulk_add_one, batch_size=100) safe_batched2 = make_none_safe(bulk_add_one, batch_size=2) safe_batched1 = make_none_safe(bulk_add_one, batch_size=1) for f in [safe_unbatched, safe_batched100, safe_batched2, safe_batched1]: assert list(f([1, 2, None, 3])) == [2, 3, None, 4] assert list(f(iter([1, 2, None, 3]))) == [2, 3, None, 4] assert list(f([])) == [] assert list(f([None])) == [None] assert list(f([1])) == [2]