Last active
February 17, 2022 18:04
-
-
Save vfdev-5/7d039c2d09479da5f788ae012b7067cb to your computer and use it in GitHub Desktop.
functorch per-sample gradient checks vs. PyTorch (torchvision models)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.models as models | |
from functorch.version import __version__ as ft_version | |
from functorch import make_functional_with_buffers, grad, vmap | |
# Names of every public, callable entry of torchvision.models: skip private
# names, capitalised names (classes), the sub-modules listed below, and any
# non-callable attribute.
tested_models = [
    name
    for name, member in models.__dict__.items()
    if not name.startswith("_")
    and not name[0].isupper()
    and name not in ["segmentation", "detection", "video", "quantization", "feature_extraction"]
    and callable(member)
]

# Loss shared by the loop-based (compute_grads) and functorch (vmap/grad)
# gradient computations; reduction="sum" adds up the per-element losses.
criterion = nn.CrossEntropyLoss(reduction="sum")
def compute_grads(model, image, target):
    """Gradient of ``criterion`` for a single sample w.r.t. ``model.parameters()``.

    ``image`` and ``target`` each receive a leading batch dimension of size 1
    before the forward pass. The global RNG is re-seeded to 0 immediately
    before the forward pass, so stochastic layers (e.g. dropout), if present,
    draw the same values on every call.

    Returns the tuple produced by ``torch.autograd.grad``: one gradient tensor
    per parameter, in ``model.parameters()`` order.
    """
    batch_of_one = image.unsqueeze(0)
    label_batch = target.unsqueeze(0)
    torch.manual_seed(0)
    sample_loss = criterion(model(batch_of_one), label_batch)
    parameters = list(model.parameters())
    return torch.autograd.grad(sample_loss, parameters)
def compute_sample_grads(images, targets, model=None):
    """Per-sample gradients computed with a plain Python loop over the batch.

    Calls ``compute_grads(model, images[i], targets[i])`` for every sample and
    stacks the results parameter by parameter.

    Args:
        images: batch indexed along dim 0; ``images[i]`` is one sample.
        targets: labels indexed along dim 0, aligned with ``images``.
        model: module to differentiate. It is a trailing optional argument only
            so the existing ``(images, targets)`` call shape stays valid;
            omitting it raises ``TypeError``.

    Returns:
        A list with one tensor per entry of ``model.parameters()``, each of
        shape ``(batch_size, *param.shape)``; an empty list for an empty batch.

    Raises:
        TypeError: if ``model`` is not provided.
    """
    if model is None:
        raise TypeError("compute_sample_grads() missing required argument: 'model'")
    batch_size = images.shape[0]
    per_sample = [compute_grads(model, images[i], targets[i]) for i in range(batch_size)]
    # zip(*...) regroups [(g0, g1, ...) for each sample] into
    # [(g0 of every sample), (g1 of every sample), ...] before stacking.
    return [torch.stack(shards) for shards in zip(*per_sample)]
# Device on which every model and input batch of the sweep is placed.
device = "cpu"
def _input_size_and_kwargs(model_name):
    """Return the (H, W) input size and extra builder kwargs for ``model_name``.

    Raises AssertionError (the sweep's "skip" signal) for convnext-like models.
    """
    if model_name == "inception_v3":
        return (299, 299), {"aux_logits": False}
    if model_name == "googlenet":
        return (224, 224), {"aux_logits": False}
    if "convnext" in model_name:
        raise AssertionError("Skip convnext-like models as they contain bernoulli_ random vmap unsupported op")
    return (224, 224), {}


def _has_dropout_or_bn(model):
    """Return True if any submodule of ``model`` is ``nn.Dropout`` or ``nn.BatchNorm2d``."""
    return any(isinstance(m, (nn.Dropout, nn.BatchNorm2d)) for m in model.modules())


def _functorch_per_sample_grads(model, images, targets):
    """Per-sample gradients via functorch: ``vmap`` over ``grad`` of the single-sample loss."""
    func_model, weights, buffers = make_functional_with_buffers(model)

    def compute_loss_ft(weights, buffers, image, target):
        # vmap maps over dim 0, so each call sees one sample; re-add a batch of 1.
        image = image.unsqueeze(0)
        target = target.unsqueeze(0)
        # Fix seed to fix dropout (mirrors compute_grads).
        torch.manual_seed(0)
        output = func_model(weights, buffers, image)
        return criterion(output, target)

    ft_compute_sample_grad = vmap(grad(compute_loss_ft), in_dims=(None, None, 0, 0))
    return ft_compute_sample_grad(weights, buffers, images, targets)


def _loop_per_sample_grads(model, images, targets):
    """Reference per-sample gradients: ``compute_grads`` per sample, stacked per parameter."""
    per_sample = [compute_grads(model, image, target) for image, target in zip(images, targets)]
    return [torch.stack(shards) for shards in zip(*per_sample)]


def check_grads_model(model_name, device):
    """Compare functorch per-sample gradients with a plain PyTorch loop for one model.

    Builds ``models.<model_name>(num_classes=10)`` under a fixed seed, draws a
    random batch of 8 images and integer targets in [0, 10), and checks that
    the functorch ``vmap(grad(...))`` gradients match per-sample gradients
    obtained by calling ``compute_grads`` once per sample.

    Args:
        model_name: key of a model builder in ``torchvision.models``.
        device: device on which the model and the batch are placed.

    Raises:
        AssertionError: both to *skip* a model (convnext-like, or containing
            ``nn.Dropout`` / ``nn.BatchNorm2d``) and to report a gradient
            mismatch; the message says which.
    """
    batch_size = 8
    torch.manual_seed(0)
    size, kwargs = _input_size_and_kwargs(model_name)
    model = models.__dict__[model_name](num_classes=10, **kwargs)

    # Temporarily skip models with dropout ops or batchnorm.
    if _has_dropout_or_bn(model):
        raise AssertionError("Model has dropout or batchnorm -> skip it")

    model = model.to(device)
    images = torch.randn(batch_size, 3, *size, device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)

    ft_per_sample_grads = _functorch_per_sample_grads(model, images, targets)
    # Reference must be *per sample*: running compute_grads on the whole batch
    # would unsqueeze it to 5-D and, at best, yield one summed gradient.
    per_sample_grads = _loop_per_sample_grads(model, images, targets)

    param_names = [name for name, _ in model.named_parameters()]
    assert len(per_sample_grads) == len(ft_per_sample_grads), (
        f"Gradient count mismatch: {len(per_sample_grads)} (loop) vs {len(ft_per_sample_grads)} (functorch)"
    )
    for name, per_sample_grad, ft_per_sample_grad in zip(param_names, per_sample_grads, ft_per_sample_grads):
        assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-3, rtol=1e-5), (
            f"Per-sample grads mismatch for parameter '{name}': "
            f"max abs diff {(per_sample_grad - ft_per_sample_grad).abs().max().item():.3e}"
        )
# Banner with the library versions this comparison ran against.
print()
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", ft_version)
print()

# Sweep every discovered model; skips and mismatches surface as
# AssertionError, whose message is printed before moving to the next model.
for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as err:
        print(err)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment