Skip to content

Instantly share code, notes, and snippets.

@fhrzn
Created October 13, 2022 08:42
Show Gist options
  • Save fhrzn/7991c366189dbffaf14ccdf02afff9c9 to your computer and use it in GitHub Desktop.
Save fhrzn/7991c366189dbffaf14ccdf02afff9c9 to your computer and use it in GitHub Desktop.
Code snippets for the Medium article "Exploring ViT with Huggingface".
from torch.utils.data import DataLoader
from datasets import load_dataset
# Load the raw image folder as a DatasetDict. Downstream code reads its
# 'image' and 'label' columns (labels presumably derived from the sub-folder
# names by the 'imagefolder' builder — confirm).
datasets = load_dataset(
    'imagefolder',
    data_dir='../input/shoe-vs-sandal-vs-boot-dataset-15k-images/Shoe vs Sandal vs Boot Dataset',
)

# Hold out 20% of the data as the 'test' split (seeded for reproducibility).
datasets = datasets['train'].train_test_split(test_size=0.2, seed=42)

# Split the remaining training data again: 80% stays 'train', 20% becomes
# 'validation'. Overall this yields roughly 64% / 16% / 20%.
datasets_split = datasets['train'].train_test_split(test_size=0.2, seed=42)
datasets['train'], datasets['validation'] = datasets_split['train'], datasets_split['test']
# `torch` itself must be imported: `from torch.utils.data import DataLoader`
# does not bind the name `torch`, which `torch.device(...)` below (and
# collate_fn) rely on.
import torch
from transformers import ViTFeatureExtractor, AutoFeatureExtractor

# Pretrained checkpoint whose preprocessing configuration is loaded below.
model_ckpt = 'google/vit-base-patch16-224-in21k'
# Use the GPU when one is available. `device` is not referenced in the code
# shown here; presumably it is used for training later — confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Feature extractor matching the checkpoint; batch_transform calls it on
# each batch of images.
extractor = ViTFeatureExtractor.from_pretrained(model_ckpt)
def batch_transform(examples):
    """Turn a batch of dataset rows into model inputs.

    Args:
        examples: A columnar batch (dict of lists) as handed over by
            ``with_transform``. Its ``'image'`` column holds PIL images and
            its ``'label'`` column holds the class labels.

    Returns:
        The feature extractor's output (PyTorch tensors, including
        ``'pixel_values'``), with the batch's labels added under ``'label'``.
    """
    # Normalise every image to 3-channel RGB. Grayscale, RGBA or palette
    # files would otherwise reach the extractor with a different channel
    # count. The pretrained ViT checkpoint is assumed to expect 3 channels;
    # confirm against the extractor config.
    images = [image.convert('RGB') for image in examples['image']]
    inputs = extractor(images, return_tensors='pt')
    # Carry the labels through so collate_fn can stack them.
    inputs['label'] = examples['label']
    return inputs
# Install batch_transform as the formatting transform on every split of the
# DatasetDict (train / test / validation).
transformed_data = datasets.with_transform(transform=batch_transform)
# Bare expression: in a notebook this displays the resulting splits.
transformed_data
# data collator
def collate_fn(examples):
    """Merge a list of transformed rows into a single training batch.

    Args:
        examples: Sequence of per-row dicts, each with a ``'pixel_values'``
            tensor and a ``'label'`` value.

    Returns:
        dict with ``'pixel_values'``, the per-row tensors stacked along a new
        leading dimension, and ``'labels'``, a tensor of the row labels.
    """
    pixel_values = torch.stack([row['pixel_values'] for row in examples])
    labels = torch.tensor([row['label'] for row in examples])
    return {'pixel_values': pixel_values, 'labels': labels}
# metrics
# `load_metric` was referenced without being imported (NameError). It is
# provided by `datasets`, which this file already depends on.
# NOTE(review): load_metric is deprecated in favour of the separate `evaluate`
# library and was removed in datasets 3.0. Pin datasets<3 or port to
# `evaluate.load('accuracy')`; verify against your installed version.
# `metric` is not referenced by compute_metrics.
from datasets import load_metric
metric = load_metric('accuracy')
def compute_metrics(p):
    """Compute accuracy and support-weighted F1 for an evaluation pass.

    The original relied on ``accuracy_score`` / ``f1_score`` names that were
    never imported. Both metrics are computed here with NumPy using the same
    definitions as scikit-learn's ``accuracy_score`` and
    ``f1_score(average='weighted')``.

    Args:
        p: Evaluation-prediction object exposing ``label_ids`` (true class
            ids) and ``predictions`` (per-class scores, class axis last).

    Returns:
        dict with float values:

        * ``'accuracy'``: the fraction of correct predictions.
        * ``'f1'``: per-class F1 averaged with weights equal to each class's
          true-label count. A class with no true positives scores 0.

        Both are 0.0 for an empty evaluation set.

    Raises:
        ValueError: if the number of predictions and labels differ.
    """
    import numpy as np

    labels = np.asarray(p.label_ids).reshape(-1)
    # Check before argmax: argmax of an empty array raises.
    if labels.size == 0:
        return {'accuracy': 0.0, 'f1': 0.0}
    preds = np.asarray(p.predictions).argmax(-1).reshape(-1)
    if preds.shape != labels.shape:
        raise ValueError(
            f'got {preds.size} predictions for {labels.size} labels'
        )

    acc = float(np.mean(preds == labels))

    weighted_sum = 0.0
    # Iterate over every class seen in either array, like sklearn does.
    for cls in np.union1d(labels, preds):
        is_pred = preds == cls
        is_true = labels == cls
        tp = np.count_nonzero(is_pred & is_true)
        fp = np.count_nonzero(is_pred & ~is_true)
        fn = np.count_nonzero(~is_pred & is_true)
        # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN). The denominator is > 0
        # because cls occurs in labels or preds.
        class_f1 = 2 * tp / (2 * tp + fp + fn)
        # Weight by support, i.e. the number of true samples of this class.
        weighted_sum += class_f1 * (tp + fn)
    f1 = weighted_sum / labels.size

    return {
        'accuracy': acc,
        'f1': f1
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment