@nateraw
Created November 9, 2022 23:31
"""
# Data
wget -nc --no-check-certificate https://www.crcv.ucf.edu/data/UCF101/UCF101.rar
unrar x UCF101.rar
# Annotations
wget -nc --no-check-certificate https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip
unzip UCF101TrainTestSplits-RecognitionTask.zip
Probably some extra unnecessary imports here
"""
from pathlib import Path
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
import pytorchvideo.data
from transformers import TrainingArguments, Trainer
from torchmetrics import Accuracy
import torch
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
    MixVideo,
)
import os
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)
import logging
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification, set_seed
from transformers.modeling_utils import unwrap_model
import itertools
from transformers.trainer_utils import get_last_checkpoint
from pytorchvideo.losses.soft_target_cross_entropy import SoftTargetCrossEntropyLoss
logger = logging.getLogger(__name__)
class LimitDataset(torch.utils.data.Dataset):
    """Wraps the iterable-style LabeledVideoDataset so it can be used as a map-style
    dataset with a fixed length (the number of videos)."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.dataset.num_videos
class VideoClassificationTrainer(Trainer):
    def __init__(self, *args, batch_transform=None, **kwargs):
        """Here we add a batch_transform argument to the Trainer for MixUp/CutMix"""
        super().__init__(*args, **kwargs)
        self.batch_transform = batch_transform

    def training_step(self, model, inputs) -> torch.Tensor:
        """Do a training step, applying the batch transform if set"""
        model.train()
        if self.batch_transform is not None:
            # Should be: (B, C, T, H, W)
            # Incoming:  (B, T, C, H, W)
            pixel_values, inputs['labels'] = self.batch_transform(
                inputs['pixel_values'].permute(0, 2, 1, 3, 4),
                inputs['labels']
            )
            # Have to set it back to the original shape
            inputs['pixel_values'] = pixel_values.permute(0, 2, 1, 3, 4)
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        return loss.detach()

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Computes SoftTargetCrossEntropyLoss when batch_transform is set
        """
        if self.batch_transform is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = SoftTargetCrossEntropyLoss()(outputs.logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
def main(
    root='./',
    seed=1337,
    batch_size=8,
    num_epochs=40,
    model_name='MCG-NJU/videomae-base',
    output_dir='videomae-base-finetuned-ucf101-nomixup',
    do_train=True,
    overwrite_output_dir=False,
    resume_from_checkpoint='/home/paperspace/Documents/video-classification/videomae-base-finetuned-ucf101/checkpoint-1500',  # None
    use_batch_transform=True,
):
    videos_path = Path(root) / 'UCF-101'
    annotations_path = Path(root) / 'ucfTrainTestlist'
    train_annotations_path = annotations_path / 'trainlist01.txt'
    test_annotations_path = annotations_path / 'testlist01.txt'

    classes = sorted(x.name for x in videos_path.glob("*") if x.is_dir())
    label2id = {label: i for i, label in enumerate(classes)}
    id2label = {i: label for i, label in enumerate(classes)}
    # Detecting last checkpoint.
    last_checkpoint = None
    if Path(output_dir).is_dir() and do_train and not overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(output_dir)
        if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
            raise ValueError(
                f"Output directory ({output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    set_seed(seed)

    feature_extractor = VideoMAEFeatureExtractor.from_pretrained(model_name)
    model = VideoMAEForVideoClassification.from_pretrained(
        model_name,
        num_labels=len(classes),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )

    mean = feature_extractor.image_mean
    std = feature_extractor.image_std
    resize_to = feature_extractor.size
    num_frames_to_sample = model.config.num_frames
    frame_sample_rate = 4
    clip_duration = num_frames_to_sample * frame_sample_rate / 30.0
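    # For reference: videomae-base uses 16 frames per clip, so with a frame sample
    # rate of 4 and the ~30 fps assumed above, each clip spans 16 * 4 / 30 ≈ 2.1 s.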
    ######################################################################
    # Prepare training and validation labeled video paths
    ######################################################################
    train_paths = LabeledVideoPaths.from_csv(train_annotations_path)
    # Ugh, the official annotations are 1-indexed instead of 0-indexed, so subtract 1
    train_paths._paths_and_labels = [
        (path, label - 1) for path, label in train_paths._paths_and_labels
    ]
    # The test split file doesn't include numeric labels, even though the videos are
    # labeled by their class directory. Derive them here so the test set can be used
    # for validation.
    val_paths = LabeledVideoPaths.from_csv(test_annotations_path)
    val_paths._paths_and_labels = [
        (path, label2id[path.split('/')[0]]) for (path, label) in val_paths._paths_and_labels
    ]
    train_paths._path_prefix = './UCF-101'
    val_paths._path_prefix = './UCF-101'
    ######################################################################
    # Prepare training and validation transforms
    ######################################################################
    train_transform = Compose(
        [
            ApplyTransformToKey(
                key="video",
                transform=Compose(
                    [
                        UniformTemporalSubsample(num_frames_to_sample),
                        Lambda(lambda x: x / 255.0),
                        Normalize(mean, std),
                        RandomShortSideScale(min_size=256, max_size=320),
                        RandomCrop(resize_to),
                        RandomHorizontalFlip(p=0.5),
                    ]
                ),
            ),
        ]
    )
    val_transform = Compose(
        [
            ApplyTransformToKey(
                key="video",
                transform=Compose(
                    [
                        UniformTemporalSubsample(num_frames_to_sample),
                        Lambda(lambda x: x / 255.0),
                        Normalize(mean, std),
                        Resize((resize_to, resize_to)),
                    ]
                ),
            ),
        ]
    )
    train_dataset = LimitDataset(
        LabeledVideoDataset(
            train_paths,
            clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
            decode_audio=False,
            transform=train_transform,
        )
    )
    # TODO - clips should be sampled with a fixed number of clips per video for a
    # correct evaluation metric, e.g. pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler
    # (see the sketch below).
    val_dataset = LimitDataset(
        LabeledVideoDataset(
            val_paths,
            clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
            decode_audio=False,
            transform=val_transform,
        )
    )
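    # Untested sketch of the TODO above: sample a fixed number of clips per video
    # for evaluation instead of the "uniform" sampler (clips_per_video=5 is an
    # arbitrary illustrative choice).
    #
    # val_dataset = LimitDataset(
    #     LabeledVideoDataset(
    #         val_paths,
    #         clip_sampler=pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler(
    #             clip_duration=clip_duration, clips_per_video=5
    #         ),
    #         decode_audio=False,
    #         transform=val_transform,
    #     )
    # )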
    top_1_acc = Accuracy()
    top_5_acc = Accuracy(top_k=2)

    def compute_metrics(eval_pred):
        """Computes accuracy on a batch of predictions"""
        preds, label_ids = torch.tensor(eval_pred.predictions), torch.tensor(eval_pred.label_ids)
        return {
            'eval_accuracy': top_1_acc(preds, label_ids),
            'eval_accuracy_top5': top_5_acc(preds, label_ids)
        }

    def collate_fn(examples):
        # permute each clip from (num_channels, num_frames, H, W) to (num_frames, num_channels, H, W),
        # so the stacked batch is (batch_size, num_frames, num_channels, H, W) as VideoMAE expects
        # TODO if we are doing batch transform, maybe don't do this permute?
        pixel_values = torch.stack(
            [example["video"].permute(1, 0, 2, 3) for example in examples]
        )
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}
    # Stochastically applies MixUp or CutMix to each batch and returns soft labels;
    # see https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/transforms/mix.py
    batch_transform = MixVideo(
        mixup_alpha=0.8,
        cutmix_prob=0.5,
        cutmix_alpha=1.0,
        label_smoothing=0.1,
        num_classes=len(classes),
    ) if use_batch_transform else None
    args = TrainingArguments(
        output_dir,
        remove_unused_columns=False,
        evaluation_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        learning_rate=5e-4,
        weight_decay=0.05,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=8,
        warmup_ratio=0.125,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        num_train_epochs=num_epochs,
        seed=seed,
        dataloader_num_workers=8,  # Play with this to find the right value.
        dataloader_drop_last=True,
        report_to='wandb',
        lr_scheduler_type='cosine',
        fp16=True,  # more memory usage for some reason... same with bf16
    )
    trainer = VideoClassificationTrainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics,
        data_collator=collate_fn,
        batch_transform=batch_transform,
    )

    checkpoint = None
    if resume_from_checkpoint is not None:
        checkpoint = resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()  # Saves the tokenizer too for easy upload
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()


if __name__ == '__main__':
    main()
nateraw commented Nov 9, 2022

Oof this hurts.

top_5_acc = Accuracy(top_k=2)
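
For an actual top-5 metric, that line would presumably need to be:

top_5_acc = Accuracy(top_k=5)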
