Skip to content

Instantly share code, notes, and snippets.

@NickCrews
Created August 25, 2024 02:47
Show Gist options
  • Save NickCrews/7d6d4096b82ad3899c176d5f9d112c0c to your computer and use it in GitHub Desktop.
Save NickCrews/7d6d4096b82ad3899c176d5f9d112c0c to your computer and use it in GitHub Desktop.
Sometimes I work with Hugging Face transformers pipelines. These can do batch inference on text with a signature of `Iterable[str] -> Iterable[str]`. I run into issues when using these with pyarrow string arrays, which can contain NULL values. I need NULLs to be preserved, and the order to be preserved, but I can't pass None to the Hugging Face pipeline. So `make_none_safe` below wraps the pipeline: it strips out the Nones, runs the pipeline on the remaining values, and re-inserts None at each original position.
from itertools import islice
from typing import Iterable, Callable, TypeVar
T = TypeVar("T")
R = TypeVar("R")
def make_none_safe(func: Callable[[Iterable[T]], Iterable[R]], *, batch_size: int | None = None) -> Callable[[Iterable[T | None]], Iterable[R]]:
"""Turn `iterable -> iterable` function into one that is safe for None values.
Consider if you have a function of the form `Iterable[T] -> Iterable[R]`,
and this function is delicate and will raise an error if it encounters
a None value in one of the inputs. This function will make a new function
that is safe for None values, and will return None for any input that is None.
Parameters
----------
func: The function that is not safe for None values.
batch_size: The number of items to process at a time. If None, then
the function will be applied to the entire input at once.
Returns
-------
A new function that is safe for None values.
Examples
--------
>>> def bulk_add_one(ins):
... for x in ins:
... yield x + 1
>>> safe = make_none_safe(bulk_add_one, batch_size=100)
>>> list(safe([1, 2, None, 3]))
[2, 3, None, 4]
"""
if batch_size is not None and batch_size < 1:
raise ValueError("batch_size must be None or a positive integer.")
if batch_size is None:
batch_size = float("inf")
def safe_func(ins: Iterable[T | None]) -> Iterable[R]:
ins_iter = iter(ins)
while True:
none_idxs = set()
non_nones = []
n_read = 0
while True:
if n_read >= batch_size:
break
try:
v = next(ins_iter)
except StopIteration:
break
if v is None:
none_idxs.add(n_read)
else:
non_nones.append(v)
n_read += 1
if n_read == 0:
break
results = iter(func(non_nones))
for j in range(n_read):
if j in none_idxs:
yield None
else:
yield next(results)
return safe_func
# tests
def bulk_add_one(ins: list[int]) -> list[int]:
    """Return a new list holding each element of *ins* plus one."""
    incremented = [value + 1 for value in ins]
    return incremented
safe_unbatched = make_none_safe(bulk_add_one)
safe_batched100 = make_none_safe(bulk_add_one, batch_size=100)
safe_batched2 = make_none_safe(bulk_add_one, batch_size=2)
safe_batched1 = make_none_safe(bulk_add_one, batch_size=1)
# Each case is (input, expected output). The interleaved-None case puts Nones
# at the start, at the end, in a run, and across the batch boundaries that
# batch_size=1 and batch_size=2 create.
_CASES = [
    ([1, 2, None, 3], [2, 3, None, 4]),
    ([], []),
    ([None], [None]),
    ([1], [2]),
    ([None, None, None], [None, None, None]),
    ([None, 1, None, None, 2, 3, None], [None, 2, None, None, 3, 4, None]),
    ([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]),
]
for f in [safe_unbatched, safe_batched100, safe_batched2, safe_batched1]:
    for ins, expected in _CASES:
        # Lists, one-shot iterators, and generators must all be accepted.
        assert list(f(ins)) == expected
        assert list(f(iter(ins))) == expected
        assert list(f(x for x in ins)) == expected
# A func that returns a generator (as in the docstring example) also works.
safe_gen = make_none_safe(lambda ins: (x + 1 for x in ins), batch_size=2)
assert list(safe_gen([1, None, 2, 3])) == [2, None, 3, 4]
# batch_size must be None or >= 1; invalid values are rejected up front.
for bad_batch_size in (0, -1):
    try:
        make_none_safe(bulk_add_one, batch_size=bad_batch_size)
    except ValueError:
        pass
    else:
        raise AssertionError(f"batch_size={bad_batch_size} should be rejected")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment