@vfdev-5
Last active February 17, 2022 18:04
functorch per-sample grads checks vs pytorch (torchvision models)
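The script below compares per-sample gradients computed with functorch's vmap(grad(...)) against a reference loop that calls torch.autograd.grad on each sample of the batch; torchvision models containing dropout, batchnorm, or vmap-unsupported random ops are skipped.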
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, grad, vmap
tested_models = []
for model_name in models.__dict__:
    if model_name.startswith("_") or model_name[0].isupper():
        continue
    if model_name in ["segmentation", "detection", "video", "quantization", "feature_extraction"]:
        continue
    if not callable(models.__dict__[model_name]):
        continue
    tested_models.append(model_name)
criterion = nn.CrossEntropyLoss(reduction="sum")
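# Each sample is wrapped in a batch of size 1 below, so the summed loss is just that sample's loss.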
def compute_grads(model, image, target):
    image = image.unsqueeze(0)
    target = target.unsqueeze(0)
    # Fix seed to fix dropout
    torch.manual_seed(0)
    output = model(image)
    loss = criterion(output, target)
    return torch.autograd.grad(loss, list(model.parameters()))
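# Reference path: compute gradients one sample at a time and stack them per parameter.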
def compute_sample_grads(model, images, targets):
    batch_size = images.shape[0]
    sample_grads = [compute_grads(model, images[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads
device = 'cpu'
def check_grads_model(model_name, device):
    batch_size = 8
    torch.manual_seed(0)
    if model_name == "inception_v3":
        size = (299, 299)
        kwargs = {"aux_logits": False}
    elif model_name == "googlenet":
        size = (224, 224)
        kwargs = {"aux_logits": False}
    elif "convnext" in model_name:
        raise AssertionError("Skip convnext-like models: they contain bernoulli_, a random op not supported by vmap")
    else:
        size = (224, 224)
        kwargs = {}
    model = models.__dict__[model_name](num_classes=10, **kwargs)
    # Temporarily skip models containing dropout or batchnorm ops
    skip_model = False

    def has_dropouts_or_BN(m):
        nonlocal skip_model
        if isinstance(m, nn.Dropout):
            skip_model = True
        elif isinstance(m, nn.BatchNorm2d):
            skip_model = True

    model.apply(has_dropouts_or_BN)
    if skip_model:
        raise AssertionError("Model has dropout or batchnorm -> skip it")
    model = model.to(device)
    images = torch.randn(batch_size, 3, *size, device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)
    func_model, weights, buffers = make_functional_with_buffers(model)
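    # functorch path: grad() differentiates the per-sample loss w.r.t. the weights, and
    # vmap() maps it over the batch dimension of images/targets while keeping weights
    # and buffers shared across samples (in_dims=(None, None, 0, 0)).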
    def compute_loss_ft(weights, buffers, image, target):
        image = image.unsqueeze(0)
        target = target.unsqueeze(0)
        # Fix seed to fix dropout
        torch.manual_seed(0)
        output = func_model(weights, buffers, image)
        loss = criterion(output, target)
        return loss

    ft_compute_grads = grad(compute_loss_ft)
    ft_compute_sample_grad = vmap(ft_compute_grads, in_dims=(None, None, 0, 0))
    ft_per_sample_grads = ft_compute_sample_grad(weights, buffers, images, targets)
    per_sample_grads = compute_sample_grads(model, images, targets)
    assert len(per_sample_grads) == len(ft_per_sample_grads)
    for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):
        assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-3, rtol=1e-5), \
            f"Per-sample grads mismatch for {model_name}"
print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", ft_version)
print("")
for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as e:
        print(e)