Code snippets for Medium article: Exploring ViT with Huggingface
from torch.utils.data import DataLoader
from datasets import load_dataset

# load the raw images with the generic imagefolder loader
datasets = load_dataset('imagefolder', data_dir='../input/shoe-vs-sandal-vs-boot-dataset-15k-images/Shoe vs Sandal vs Boot Dataset')

# hold out 20% of the images as a test set
datasets = datasets['train'].train_test_split(test_size=.2, seed=42)

# split the remaining training data again to get a validation set
datasets_split = datasets['train'].train_test_split(test_size=.2, seed=42)
datasets['train'] = datasets_split['train']
datasets['validation'] = datasets_split['test']
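A quick sanity check, not part of the original snippet: printing the resulting DatasetDict and the class names inferred by the imagefolder loader confirms the three splits (the exact counts and label names depend on the dataset folder layout).

# inspect the splits and the class names discovered by the imagefolder loader
print(datasets)
print(datasets['train'].features['label'].names)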
import torch
from transformers import ViTFeatureExtractor, AutoFeatureExtractor

model_ckpt = 'google/vit-base-patch16-224-in21k'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
extractor = ViTFeatureExtractor.from_pretrained(model_ckpt)

def batch_transform(examples):
    # take a list of PIL images and turn them into pixel values
    inputs = extractor([x for x in examples['image']], return_tensors='pt')
    # add the labels in
    inputs['label'] = examples['label']
    return inputs

# apply the transform lazily, only when examples are accessed
transformed_data = datasets.with_transform(batch_transform)
transformed_data
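An illustrative check of the lazy transform (assuming the splits above): slicing a couple of examples triggers batch_transform, and for this checkpoint the pixel values come back as 3-channel, 224x224 tensors.

# sanity check: the transform runs on access
batch = transformed_data['train'][0:2]
print(batch['pixel_values'].shape)   # expected: torch.Size([2, 3, 224, 224])
print(batch['label'])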
import torch
from datasets import load_metric
from sklearn.metrics import accuracy_score, f1_score

# data collator: stack the transformed images and labels into batch tensors
def collate_fn(examples):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in examples]),
        'labels': torch.tensor([x['label'] for x in examples])
    }

# metrics
metric = load_metric('accuracy')

def compute_metrics(p):
    labels = p.label_ids
    preds = p.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')
    return {
        'accuracy': acc,
        'f1': f1
    }
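These pieces plug into a Hugging Face Trainer. The sketch below is a minimal, assumed setup, not the article's exact configuration: the output_dir and hyperparameters are placeholders. remove_unused_columns=False is needed so the raw 'image' column survives for batch_transform to consume.

from transformers import ViTForImageClassification, TrainingArguments, Trainer

labels = datasets['train'].features['label'].names

# illustrative model and training setup
model = ViTForImageClassification.from_pretrained(model_ckpt, num_labels=len(labels)).to(device)

args = TrainingArguments(
    output_dir='./vit-shoe-sandal-boot',   # placeholder path
    per_device_train_batch_size=16,
    num_train_epochs=3,
    evaluation_strategy='epoch',
    remove_unused_columns=False,           # keep the raw 'image' column for the transform
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=transformed_data['train'],
    eval_dataset=transformed_data['validation'],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=extractor,                   # so the feature extractor is saved with checkpoints
)
trainer.train()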