""" | |
# Data | |
wget -nc --no-check-certificate https://www.crcv.ucf.edu/data/UCF101/UCF101.rar | |
unrar x UCF101.rar | |
# Annotations | |
wget -nc --no-check-certificate https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip | |
unzip UCF101TrainTestSplits-RecognitionTask.zip | |
Probably some extra unnecessary imports here | |
""" | |
import itertools
import logging
import os
from pathlib import Path

import torch
from torchmetrics import Accuracy
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)

import pytorchvideo.data
from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
from pytorchvideo.losses.soft_target_cross_entropy import SoftTargetCrossEntropyLoss
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    MixVideo,
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)

from transformers import (
    Trainer,
    TrainingArguments,
    VideoMAEFeatureExtractor,
    VideoMAEForVideoClassification,
    set_seed,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_utils import get_last_checkpoint

logger = logging.getLogger(__name__)
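# LabeledVideoDataset is an iterable-style dataset without __len__, which the Trainer needs
# for epoch-based training. LimitDataset wraps it in a map-style Dataset whose length is the
# number of videos, drawing clips from the underlying iterator on each __getitem__.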
class LimitDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.dataset.num_videos
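# The stock Trainer computes the loss inside the model from integer labels. MixUp/CutMix
# instead produce soft (mixed) label vectors, so this subclass applies the batch transform
# in training_step and overrides compute_loss to use SoftTargetCrossEntropyLoss on the logits.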
class VideoClassificationTrainer(Trainer):
    def __init__(self, *args, batch_transform=None, **kwargs):
        """Adds a batch_transform argument to the Trainer for MixUp/CutMix."""
        super().__init__(*args, **kwargs)
        self.batch_transform = batch_transform

    def training_step(self, model, inputs) -> torch.Tensor:
        """Do a training step, applying the batch transform if one is set."""
        model.train()
        if self.batch_transform is not None:
            # Should be: (B, C, T, H, W)
            # Incoming:  (B, T, C, H, W)
            pixel_values, inputs['labels'] = self.batch_transform(
                inputs['pixel_values'].permute(0, 2, 1, 3, 4),
                inputs['labels']
            )
            # Permute back to the original (B, T, C, H, W) layout
            inputs['pixel_values'] = pixel_values.permute(0, 2, 1, 3, 4)

        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        return loss.detach()
    def compute_loss(self, model, inputs, return_outputs=False):
        """Computes SoftTargetCrossEntropyLoss when a batch_transform is set."""
        if self.batch_transform is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = SoftTargetCrossEntropyLoss()(outputs.logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
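# End-to-end fine-tuning pipeline: build label mappings from the class folders, load the
# pretrained VideoMAE backbone with a fresh classification head, wrap UCF101 in
# pytorchvideo datasets/transforms, and hand everything to the custom Trainer above.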
def main(
    root='./',
    seed=1337,
    batch_size=8,
    num_epochs=40,
    model_name='MCG-NJU/videomae-base',
    output_dir='videomae-base-finetuned-ucf101-nomixup',
    do_train=True,
    overwrite_output_dir=False,
    resume_from_checkpoint='/home/paperspace/Documents/video-classification/videomae-base-finetuned-ucf101/checkpoint-1500',  # None
    use_batch_transform=True,
):
    videos_path = Path(root) / 'UCF-101'
    annotations_path = Path(root) / 'ucfTrainTestlist'
    train_annotations_path = annotations_path / 'trainlist01.txt'
    test_annotations_path = annotations_path / 'testlist01.txt'

    classes = sorted(x.name for x in videos_path.glob("*") if x.is_dir())
    label2id = {label: i for i, label in enumerate(classes)}
    id2label = {i: label for i, label in enumerate(classes)}

    # Detecting last checkpoint.
    last_checkpoint = None
    if Path(output_dir).is_dir() and do_train and not overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(output_dir)
        if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
            raise ValueError(
                f"Output directory ({output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    set_seed(seed)

    feature_extractor = VideoMAEFeatureExtractor.from_pretrained(model_name)
    model = VideoMAEForVideoClassification.from_pretrained(
        model_name,
        num_labels=len(classes),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )

    mean = feature_extractor.image_mean
    std = feature_extractor.image_std
    resize_to = feature_extractor.size
    num_frames_to_sample = model.config.num_frames
    frame_sample_rate = 4
    clip_duration = num_frames_to_sample * frame_sample_rate / 30.0
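    # With VideoMAE's default 16-frame input, sampling every 4th frame at the assumed 30 fps
    # gives clips spanning roughly 16 * 4 / 30 ≈ 2.1 seconds of video.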
    ######################################################################
    # Prepare training and validation labeled video paths
    ######################################################################
    train_paths = LabeledVideoPaths.from_csv(train_annotations_path)
    # The official train annotations are 1-indexed, so shift labels down to be 0-indexed.
    train_paths._paths_and_labels = [
        (path, label - 1) for path, label in train_paths._paths_and_labels
    ]

    # The test split file lists only paths (no numeric labels), so derive each label from
    # the class folder name. This lets us use the test split for validation.
    val_paths = LabeledVideoPaths.from_csv(test_annotations_path)
    val_paths._paths_and_labels = [
        (path, label2id[path.split('/')[0]]) for (path, label) in val_paths._paths_and_labels
    ]

    train_paths._path_prefix = './UCF-101'
    val_paths._path_prefix = './UCF-101'
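    # Note: `_paths_and_labels` and `_path_prefix` are private attributes of LabeledVideoPaths,
    # so the label fix-ups above may break on other pytorchvideo versions.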
    ######################################################################
    # Prepare training and validation transforms
    ######################################################################
    train_transform = Compose(
        [
            ApplyTransformToKey(
                key="video",
                transform=Compose(
                    [
                        UniformTemporalSubsample(num_frames_to_sample),
                        Lambda(lambda x: x / 255.0),
                        Normalize(mean, std),
                        RandomShortSideScale(min_size=256, max_size=320),
                        RandomCrop(resize_to),
                        RandomHorizontalFlip(p=0.5),
                    ]
                ),
            ),
        ]
    )
    val_transform = Compose(
        [
            ApplyTransformToKey(
                key="video",
                transform=Compose(
                    [
                        UniformTemporalSubsample(num_frames_to_sample),
                        Lambda(lambda x: x / 255.0),
                        Normalize(mean, std),
                        Resize((resize_to, resize_to)),
                    ]
                ),
            ),
        ]
    )
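    # Both pipelines operate on the "video" tensor of each clip dict: the train transform adds
    # random scale/crop/flip augmentation, while the val transform only does a deterministic resize.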
    train_dataset = LimitDataset(
        LabeledVideoDataset(
            train_paths,
            clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
            decode_audio=False,
            transform=train_transform,
        )
    )
    # TODO - clips should be sampled with a fixed number of clips per video for a correct
    # evaluation calculation, e.g. pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler
    val_dataset = LimitDataset(
        LabeledVideoDataset(
            val_paths,
            clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
            decode_audio=False,
            transform=val_transform,
        )
    )
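    # The "random" clip sampler draws one random clip of `clip_duration` seconds each time a
    # video is visited (good for training); "uniform" instead walks consecutive clips across
    # each video, which is only an approximation of proper multi-clip evaluation (see TODO above).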
    top_1_acc = Accuracy()
    top_5_acc = Accuracy(top_k=5)

    def compute_metrics(eval_pred):
        """Computes top-1 and top-5 accuracy on a batch of predictions."""
        preds, label_ids = torch.tensor(eval_pred.predictions), torch.tensor(eval_pred.label_ids)
        return {
            'eval_accuracy': top_1_acc(preds, label_ids),
            'eval_accuracy_top5': top_5_acc(preds, label_ids)
        }
    def collate_fn(examples):
        # Permute each clip to (num_frames, num_channels, height, width)
        # TODO: if we are doing a batch transform, maybe don't do this permute?
        pixel_values = torch.stack(
            [example["video"].permute(1, 0, 2, 3) for example in examples]
        )
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}
    # https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/transforms/mix.py
    batch_transform = MixVideo(
        mixup_alpha=0.8,
        cutmix_prob=0.5,
        cutmix_alpha=1.0,
        label_smoothing=0.1,
        num_classes=len(classes),
    ) if use_batch_transform else None
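    # MixVideo stochastically applies either MixUp or CutMix to each batch (CutMix with
    # probability `cutmix_prob`) and returns label-smoothed soft targets, which is what the
    # SoftTargetCrossEntropyLoss in the custom Trainer expects.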
    args = TrainingArguments(
        output_dir,
        remove_unused_columns=False,
        evaluation_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        learning_rate=5e-4,
        weight_decay=0.05,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=8,
        warmup_ratio=0.125,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        num_train_epochs=num_epochs,
        seed=seed,
        dataloader_num_workers=8,  # Play with this to find the right value.
        dataloader_drop_last=True,
        report_to='wandb',
        lr_scheduler_type='cosine',
        fp16=True,  # Oddly uses more memory for some reason... same with bf16.
    )
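    # With per_device_train_batch_size=8 and gradient_accumulation_steps=8, each optimizer
    # step sees an effective batch of 64 clips per device.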
    trainer = VideoClassificationTrainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics,
        data_collator=collate_fn,
        batch_transform=batch_transform,
    )

    checkpoint = None
    if resume_from_checkpoint is not None:
        checkpoint = resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()  # Saves the feature extractor too for easy upload
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()


if __name__ == '__main__':
    main()
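# Usage (a rough sketch): run this file directly, e.g. `python finetune_videomae_ucf101.py`
# (the filename here is illustrative), from a directory containing UCF-101/ and
# ucfTrainTestlist/. The hardcoded `resume_from_checkpoint` path above is machine-specific;
# set it to None (or your own checkpoint) before a fresh run.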