from typing import Iterable, Callable, TypeVar

T = TypeVar("T")
R = TypeVar("R")

def make_none_safe(func: Callable[[Iterable[T]], Iterable[R]], *, batch_size: int | None = None) -> Callable[[Iterable[T | None]], Iterable[R]]:
    """Turn `iterable -> iterable` function into one that is safe for None values.
    
    Consider if you have a function of the form `Iterable[T] -> Iterable[R]`,
    and this function is delicate and will raise an error if it encounters
    a None value in one of the inputs. This function will make a new function
    that is safe for None values, and will return None for any input that is None.

    Parameters
    ----------
    func: The function that is not safe for None values.
    batch_size: The number of items to process at a time. If None, then
        the function will be applied to the entire input at once.
    
    Returns
    -------
    A new function that is safe for None values.

    Examples
    --------
    >>> def bulk_add_one(ins):
    ...     for x in ins:
    ...         yield x + 1
    >>> safe = make_none_safe(bulk_add_one, batch_size=100)
    >>> list(safe([1, 2, None, 3]))
    [2, 3, None, 4]
    """

    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be None or a positive integer.")
    if batch_size is None:
        batch_size = float("inf")

    def safe_func(ins: Iterable[T | None]) -> Iterable[R]:
        ins_iter = iter(ins)
        while True:
            none_idxs = set()
            non_nones = []
            n_read = 0
            while True:
                if n_read >= batch_size:
                    break
                try:
                    v = next(ins_iter)
                except StopIteration:
                    break
                if v is None:
                    none_idxs.add(n_read)
                else:
                    non_nones.append(v)
                n_read += 1
            if n_read == 0:
                break
            results = iter(func(non_nones))
            for j in range(n_read):
                if j in none_idxs:
                    yield None
                else:
                    yield next(results)
    return safe_func

# Smoke tests, executed whenever this module is imported or run.
def bulk_add_one(ins: list[int]) -> list[int]:
    """Return a new list holding every element of `ins` incremented by one."""
    incremented = []
    for value in ins:
        incremented.append(value + 1)
    return incremented


# One wrapper per batching mode: everything at once, a batch larger than any
# test input, a batch that splits the test inputs, and one item per batch.
safe_unbatched = make_none_safe(bulk_add_one)
safe_batched100 = make_none_safe(bulk_add_one, batch_size=100)
safe_batched2 = make_none_safe(bulk_add_one, batch_size=2)
safe_batched1 = make_none_safe(bulk_add_one, batch_size=1)

for f in (safe_unbatched, safe_batched100, safe_batched2, safe_batched1):
    # Mixed input supplied as a re-iterable list.
    mixed_from_list = list(f([1, 2, None, 3]))
    assert mixed_from_list == [2, 3, None, 4]

    # The same input supplied as a one-shot iterator.
    mixed_from_iterator = list(f(iter([1, 2, None, 3])))
    assert mixed_from_iterator == [2, 3, None, 4]

    # Degenerate inputs: nothing at all, only a None, only a real value.
    assert not list(f([]))
    assert list(f([None])) == [None]
    assert list(f([1])) == [2]