Skip to content

Instantly share code, notes, and snippets.

@maziyarpanahi
Created August 14, 2022 17:10
Show Gist options
  • Save maziyarpanahi/da544aa052dd81654301fec063333b19 to your computer and use it in GitHub Desktop.
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import pipeline
import torch
# Decide once whether CUDA is usable and derive both device forms from that
# single check: `device` is the torch device string used to place the model,
# `pipeline_device` is the integer ordinal passed to `transformers.pipeline`
# (-1 selects CPU, 0 the first CUDA device).
use_cuda = torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
pipeline_device = 0 if use_cuda else -1
print(device)

# Checkpoint shared by the feature extractor and the classification model.
MODEL_ID = 'google/vit-base-patch16-224'

# NOTE(review): newer transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor (and the pipeline's `image_processor=` kwarg);
# kept as-is for compatibility with older versions — confirm against the
# installed version before switching.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
model = ViTForImageClassification.from_pretrained(MODEL_ID)
model = model.to(device)

# Use the ordinal derived from the same CUDA check as `device`; a hard-coded
# 0 would request cuda:0 even on CPU-only machines and fail there.
pipe = pipeline(
    "image-classification",
    model=model,
    feature_extractor=feature_extractor,
    device=pipeline_device,
)
# Batch sizes to sweep; each run streams the whole dataset through `pipe`.
BATCH_SIZES = (1, 8, 32, 64, 128, 256, 512, 1024)

# NOTE(review): `dataset` and `tqdm` are neither defined nor imported in this
# snippet; they must be provided by the surrounding code/notebook. `dataset`
# is assumed to be an iterable of images that also supports len() — TODO
# confirm.
for batch_size in BATCH_SIZES:
    print("-" * 30)
    print(f"Streaming batch_size={batch_size}")
    # Iterating the pipeline output is what drives inference; the predictions
    # are discarded, presumably because only the throughput reported by tqdm
    # (against the dataset length) is of interest.
    for _ in tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment