Skip to content

Instantly share code, notes, and snippets.

@Narsil
Created October 11, 2021 13:44
Show Gist options
  • Save Narsil/4e1c36d7cf8477e5c1d580585860810e to your computer and use it in GitHub Desktop.
Save Narsil/4e1c36d7cf8477e5c1d580585860810e to your computer and use it in GitHub Desktop.
import torch
import tqdm
from torch.utils.data import Dataset
from transformers import pipeline
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# (device=-1) so the script still runs on machines without a GPU.
pipe = pipeline(
    "text-classification",
    device=0 if torch.cuda.is_available() else -1,
)
class MyDataset(Dataset):
    """Synthetic map-style dataset that returns the same string for every index.

    Args:
        size: Number of items reported by ``len()``. Defaults to 5000.
        text: String returned for every valid index. Defaults to
            ``"This is a test"``.

    Raises:
        ValueError: If ``size`` is negative.
    """

    def __init__(self, size: int = 5000, text: str = "This is a test") -> None:
        super().__init__()
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        self._size = size
        self._text = text

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return self._size

    def __getitem__(self, i: int) -> str:
        """Return the sample text for index ``i``.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        # Without a bound check, iterating the dataset through the
        # __getitem__ sequence protocol would never see an IndexError and
        # therefore never stop.
        if not -self._size <= i < self._size:
            raise IndexError(
                f"index {i} is out of range for dataset of size {self._size}"
            )
        return self._text
dataset = MyDataset()


def _stream(label, **pipe_kwargs):
    """Print a header for ``label`` and drain ``pipe(dataset, **pipe_kwargs)``.

    Args:
        label: Text appended to ``"Streaming "`` in the printed header.
        **pipe_kwargs: Extra keyword arguments forwarded to ``pipe`` (for
            example ``batch_size``).
    """
    print("-" * 30)
    print(f"Streaming {label}")
    # Pass total explicitly on every run so all three progress bars are
    # measured against the number of items in the dataset.
    for _ in tqdm.tqdm(pipe(dataset, **pipe_kwargs), total=len(dataset)):
        # Results are discarded; the loop only drives the iterator.
        pass


_stream("no batching")
_stream("batch_size=8", batch_size=8)
_stream("batch_size=64", batch_size=64)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment