Created
August 25, 2024 02:47
-
-
Save NickCrews/7d6d4096b82ad3899c176d5f9d112c0c to your computer and use it in GitHub Desktop.
Sometimes I work with Hugging Face `transformers` pipelines. These can do batch inference on text with a signature of `Iterable[str] -> Iterable[str]`. I run into issues when using these with pyarrow string arrays, which can contain NULL values. I need NULLs to be preserved, and the order to be preserved, but I can't pass None to the Hugging Face pipeline…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import islice
from typing import Callable, Iterable, TypeVar
# Element type accepted by the delicate batch function wrapped by make_none_safe.
T = TypeVar("T")
# Element type produced by that batch function.
R = TypeVar("R")
def make_none_safe(
    func: Callable[[Iterable[T]], Iterable[R]],
    *,
    batch_size: int | None = None,
) -> Callable[[Iterable[T | None]], Iterable[R | None]]:
    """Turn an `iterable -> iterable` function into one that is safe for None values.

    Consider a function of the form `Iterable[T] -> Iterable[R]` that is
    delicate and raises if it encounters a None among its inputs. This
    builds a new function that filters the None values out before `func`
    sees them and emits None in their place in the output. The output has
    the same length and order as the input.

    Parameters
    ----------
    func: The function that is not safe for None values. It must return
        exactly one result per input it is given, in the same order.
    batch_size: The maximum number of input items (None or not) to read
        before calling `func`. If None, the entire input is consumed and
        `func` is called once.

    Returns
    -------
    A new function that is safe for None values. It is a generator function,
    so input is consumed lazily, one batch at a time.

    Raises
    ------
    ValueError: Raised immediately if `batch_size` is not None and is less
        than 1. Also raised while iterating the output if `func` returns a
        different number of results than the inputs it was given.

    Examples
    --------
    >>> def bulk_add_one(ins):
    ...     for x in ins:
    ...         yield x + 1
    >>> safe = make_none_safe(bulk_add_one, batch_size=100)
    >>> list(safe([1, 2, None, 3]))
    [2, 3, None, 4]
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be None or a positive integer.")

    # Private sentinel that distinguishes a legitimate None result from
    # `func` from "func ran out of results".
    missing = object()

    def safe_func(ins: Iterable[T | None]) -> Iterable[R | None]:
        ins_iter = iter(ins)
        while True:
            # islice(..., None) drains the whole iterator, which gives the
            # "unbatched" behaviour requested by batch_size=None.
            batch = list(islice(ins_iter, batch_size))
            if not batch:
                return
            non_nones = [v for v in batch if v is not None]
            # A batch made up entirely of None needs no results, so don't call
            # func with an empty input it may not handle.
            results = iter(func(non_nones)) if non_nones else iter(())
            for v in batch:
                if v is None:
                    yield None
                    continue
                # Use a default rather than letting StopIteration escape: inside
                # a generator it would surface as an opaque RuntimeError (PEP 479).
                result = next(results, missing)
                if result is missing:
                    raise ValueError(
                        "func returned fewer results than the "
                        f"{len(non_nones)} inputs it was given."
                    )
                yield result
            # Extra results would otherwise be dropped silently and hide a bug.
            if next(results, missing) is not missing:
                raise ValueError(
                    "func returned more results than the "
                    f"{len(non_nones)} inputs it was given."
                )

    return safe_func
# tests
def bulk_add_one(ins: list[int]) -> list[int]:
    """Return a new list holding each element of `ins` plus one.

    Serves as the "delicate" batch function in the checks below: a None
    element makes `value + 1` raise TypeError.
    """
    return [value + 1 for value in ins]
safe_unbatched = make_none_safe(bulk_add_one)
safe_batched100 = make_none_safe(bulk_add_one, batch_size=100)
safe_batched2 = make_none_safe(bulk_add_one, batch_size=2)
safe_batched1 = make_none_safe(bulk_add_one, batch_size=1)

# (input, expected output) pairs. Nones must come back in place, including
# runs of Nones that fill an entire batch or straddle a batch boundary.
_CASES = [
    ([1, 2, None, 3], [2, 3, None, 4]),
    ([], []),
    ([None], [None]),
    ([1], [2]),
    ([None, None], [None, None]),
    ([None, None, 1, None, None, 2], [None, None, 2, None, None, 3]),
    ([1, None, None, None, 2, 3], [2, None, None, None, 3, 4]),
]
for f in [safe_unbatched, safe_batched100, safe_batched2, safe_batched1]:
    for ins, expected in _CASES:
        # Both re-iterable lists and one-shot iterators must be supported.
        assert list(f(ins)) == expected
        assert list(f(iter(ins))) == expected

# A non-positive batch size is rejected when the wrapper is built.
try:
    make_none_safe(bulk_add_one, batch_size=0)
except ValueError:
    pass
else:
    raise AssertionError("batch_size=0 should be rejected")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment