functorch make_functional + grad checks vs PyTorch-computed grads (torchvision models)
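Each of the three scripts below builds every torchvision model in a family (detection, segmentation, classification), computes gradients once with plain autograd via loss.backward() and once with functorch via make_functional / make_functional_with_buffers plus grad, and asserts that the two sets of gradients agree. The detection-model output log appears first, followed by the three scripts and then the classification and segmentation logs. A minimal sketch of that comparison on a toy nn.Linear model (the model, data and tolerance here are illustrative and not part of the gist):

import torch
import torch.nn as nn
from functorch import make_functional, grad

torch.manual_seed(0)
model = nn.Linear(4, 2)
x = torch.randn(8, 4)
y = torch.randn(8, 2)

# Reference gradients via autograd
loss = ((model(x) - y) ** 2).mean()
loss.backward()

# functorch: turn the module into a pure function of its parameters, then differentiate
func_model, params = make_functional(model)

def compute_loss(params, x, y):
    out = func_model(params, x)
    return ((out - y) ** 2).mean()

param_grads = grad(compute_loss)(params, x, y)

# The two gradient sets should match up to floating-point noise
for (name, p), g in zip(model.named_parameters(), param_grads):
    assert p.grad.allclose(g, atol=1e-6), name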
Torch: 1.12.0.dev20220215+cu111
torchvision: 0.13.0.dev20220215+cu111
Functorch: 0.2.0a0+c9d03e8

-- Check fasterrcnn_resnet50_fpn model
-- Check fasterrcnn_mobilenet_v3_large_320_fpn model
-- Check fasterrcnn_mobilenet_v3_large_fpn model
-- Check maskrcnn_resnet50_fpn model
-- Check keypointrcnn_resnet50_fpn model
-- Check retinanet_resnet50_fpn model
-- Check ssd300_vgg16 model
-- Check ssdlite320_mobilenet_v3_large model
-- Check fcos_resnet50_fpn model
# Check functorch make_functional(_with_buffers) + grad against autograd
# for torchvision detection models.
import torch
import torch.nn as nn
import torchvision
import torchvision.models.detection as models

from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad

# Collect the public detection model builders exposed by torchvision.models.detection
tested_models = []
for model_name in models.__dict__:
    if model_name.startswith("_") or model_name[0].isupper():
        continue
    if not callable(models.__dict__[model_name]):
        continue
    tested_models.append(model_name)


def compute_grads(model, images, targets):
    # Fix the seed so dropout matches the functorch pass
    torch.manual_seed(0)
    loss_dict = model(images, targets)
    loss = sum(loss for loss in loss_dict.values())
    loss.backward()


device = "cpu"


def check_grads_model(model_name, device):
    batch_size = 8
    torch.manual_seed(0)
    size = (224, 224)

    model = models.__dict__[model_name](num_classes=10, pretrained=False, pretrained_backbone=False)
    model = model.to(device)

    # Synthetic inputs: each target carries boxes/labels plus keypoints and masks
    # so that all detection model families can compute their losses
    images = [torch.rand(3, 224, 224) for _ in range(4)]
    targets = [
        {
            "boxes": torch.tensor([[10 + i, 10 + i, 20 + i, 20 + i], [20 + i, 20 + i, 30 + i, 30 + i]]),
            "labels": torch.tensor([(1 + i) % 10, (2 + i) % 10]),
            "keypoints": torch.rand(2, 12, 3),
            "masks": torch.randint(0, 1, size=(2, 224, 224), dtype=torch.uint8),
        }
        for i in range(4)
    ]

    # Models with buffers (e.g. BatchNorm running stats) need make_functional_with_buffers
    has_buffers = len(list(model.buffers())) > 0
    if has_buffers:
        gen_make_functional_fn = make_functional_with_buffers
    else:
        gen_make_functional_fn = make_functional

    output = gen_make_functional_fn(model)
    if has_buffers:
        func_model, weights, buffers = output
    else:
        func_model, weights = output
        buffers = None

    def compute_loss_ft(weights, buffers, images, targets):
        # Fix the seed so dropout matches the eager pass
        torch.manual_seed(0)
        if buffers is None:
            loss_dict = func_model(weights, images, targets)
        else:
            loss_dict = func_model(weights, buffers, images, targets)
        loss = sum(loss for loss in loss_dict.values())
        return loss

    compute_grad = grad(compute_loss_ft)
    w_grads = compute_grad(weights, buffers, images, targets)

    compute_grads(model, images, targets)

    assert len(w_grads) == len(list(model.parameters()))
    for wg, (n, p) in zip(w_grads, model.named_parameters()):
        assert p.grad.allclose(wg, atol=1e-5), f"grad mismatch for {n}: {p.grad.mean()} vs {wg.mean()}"


print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", ft_version)
print("")

for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as e:
        print(e)
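A note on the functorch call in the script above: grad(compute_loss_ft) differentiates only with respect to the first positional argument (argnums=0 is the default), so w_grads holds gradients for weights while buffers, images and targets are treated as constants. A small standalone illustration of that default (the function f here is hypothetical):

import torch
from functorch import grad

def f(w, x):
    return (w * x).sum()

w = torch.tensor([1.0, 2.0])
x = torch.tensor([3.0, 4.0])

print(grad(f)(w, x))             # df/dw -> tensor([3., 4.])
print(grad(f, argnums=1)(w, x))  # df/dx -> tensor([1., 2.])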
# Check functorch make_functional(_with_buffers) + grad against autograd
# for torchvision segmentation models.
import torch
import torch.nn as nn
import torchvision
import torchvision.models.segmentation as models

from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad

# Collect the public segmentation model builders exposed by torchvision.models.segmentation
tested_models = []
for model_name in models.__dict__:
    if model_name.startswith("_") or model_name[0].isupper():
        continue
    if not callable(models.__dict__[model_name]):
        continue
    tested_models.append(model_name)

criterion = nn.CrossEntropyLoss()


def compute_grads(model, image, target):
    # Fix the seed so dropout matches the functorch pass
    torch.manual_seed(0)
    output = model(image)
    loss = criterion(output["out"], target)
    loss.backward()


device = "cpu"


def check_grads_model(model_name, device):
    batch_size = 8
    torch.manual_seed(0)
    size = (224, 224)

    model = models.__dict__[model_name](num_classes=10, pretrained=False, pretrained_backbone=False)
    model = model.to(device)

    images = torch.randn(batch_size, 3, *size, device=device)
    targets = torch.randint(0, 10, (batch_size,) + size, device=device)

    # Models with buffers (e.g. BatchNorm running stats) need make_functional_with_buffers
    has_buffers = len(list(model.buffers())) > 0
    if has_buffers:
        gen_make_functional_fn = make_functional_with_buffers
    else:
        gen_make_functional_fn = make_functional

    output = gen_make_functional_fn(model)
    if has_buffers:
        func_model, weights, buffers = output
    else:
        func_model, weights = output
        buffers = None

    def compute_loss_ft(weights, buffers, image, target):
        # Fix the seed so dropout matches the eager pass
        torch.manual_seed(0)
        if buffers is None:
            output = func_model(weights, image)
        else:
            output = func_model(weights, buffers, image)
        loss = criterion(output["out"], target)
        return loss

    compute_grad = grad(compute_loss_ft)
    w_grads = compute_grad(weights, buffers, images, targets)

    compute_grads(model, images, targets)

    assert len(w_grads) == len(list(model.parameters()))
    for wg, (n, p) in zip(w_grads, model.named_parameters()):
        assert p.grad.allclose(wg, atol=1e-5), f"grad mismatch for {n}: {p.grad.mean()} vs {wg.mean()}"


print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", ft_version)
print("")

for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as e:
        print(e)
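All three scripts branch on whether the model has buffers: for models carrying BatchNorm running statistics, make_functional_with_buffers returns a (func_model, params, buffers) triple, and the functional model is then called with the parameters and buffers up front, followed by the usual inputs. A minimal standalone sketch of that call convention (the toy model below is illustrative):

import torch
import torch.nn as nn
from functorch import make_functional_with_buffers

# BatchNorm keeps running-mean/var buffers, so the *_with_buffers variant applies
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
func_model, params, buffers = make_functional_with_buffers(model)

x = torch.randn(2, 3, 16, 16)
out = func_model(params, buffers, x)
print(out.shape)  # torch.Size([2, 8, 14, 14])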
# Check functorch make_functional(_with_buffers) + grad against autograd
# for torchvision classification models.
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models

from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad

# Collect the public classification model builders, skipping the submodule namespaces
tested_models = []
for model_name in models.__dict__:
    if model_name.startswith("_") or model_name[0].isupper():
        continue
    if model_name in ["segmentation", "detection", "video", "quantization", "feature_extraction"]:
        continue
    if not callable(models.__dict__[model_name]):
        continue
    tested_models.append(model_name)

criterion = nn.CrossEntropyLoss()


def compute_grads(model, image, target):
    # Fix the seed so dropout matches the functorch pass
    torch.manual_seed(0)
    output = model(image)
    loss = criterion(output, target)
    loss.backward()


device = "cpu"


def check_grads_model(model_name, device):
    batch_size = 8
    torch.manual_seed(0)

    # inception_v3 needs 299x299 inputs; aux classifiers are disabled so the
    # model returns a single logits tensor
    if model_name == "inception_v3":
        size = (299, 299)
        kwargs = {"aux_logits": False}
    elif model_name == "googlenet":
        size = (224, 224)
        kwargs = {"aux_logits": False}
    else:
        size = (224, 224)
        kwargs = {}

    model = models.__dict__[model_name](num_classes=10, **kwargs)
    model = model.to(device)

    images = torch.randn(batch_size, 3, *size, device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)

    # Models with buffers (e.g. BatchNorm running stats) need make_functional_with_buffers
    has_buffers = len(list(model.buffers())) > 0
    if has_buffers:
        gen_make_functional_fn = make_functional_with_buffers
    else:
        gen_make_functional_fn = make_functional

    output = gen_make_functional_fn(model)
    if has_buffers:
        func_model, weights, buffers = output
    else:
        func_model, weights = output
        buffers = None

    def compute_loss_ft(weights, buffers, image, target):
        # Fix the seed so dropout matches the eager pass
        torch.manual_seed(0)
        if buffers is None:
            output = func_model(weights, image)
        else:
            output = func_model(weights, buffers, image)
        loss = criterion(output, target)
        return loss

    compute_grad = grad(compute_loss_ft)
    w_grads = compute_grad(weights, buffers, images, targets)

    compute_grads(model, images, targets)

    assert len(w_grads) == len(list(model.parameters()))
    for wg, (n, p) in zip(w_grads, model.named_parameters()):
        assert p.grad.allclose(wg, atol=1e-5), f"grad mismatch for {n}: {p.grad.mean()} vs {wg.mean()}"


print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", ft_version)
print("")

for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as e:
        print(e)
Torch: 1.12.0.dev20220209+cu111
torchvision: 0.13.0.dev20220209+cu111
Functorch: 0.2.0a0+da6ec37

-- Check alexnet model
-- Check convnext_tiny model
-- Check convnext_small model
-- Check convnext_base model
-- Check convnext_large model
-- Check resnet18 model
-- Check resnet34 model
-- Check resnet50 model
-- Check resnet101 model
-- Check resnet152 model
-- Check resnext50_32x4d model
-- Check resnext101_32x8d model
-- Check wide_resnet50_2 model
-- Check wide_resnet101_2 model
-- Check vgg11 model
-- Check vgg11_bn model
-- Check vgg13 model
-- Check vgg13_bn model
-- Check vgg16 model
-- Check vgg16_bn model
-- Check vgg19_bn model
-- Check vgg19 model
-- Check squeezenet1_0 model
-- Check squeezenet1_1 model
-- Check inception_v3 model
-- Check densenet121 model
-- Check densenet169 model
-- Check densenet201 model
-- Check densenet161 model
-- Check googlenet model
-- Check mobilenet_v2 model
-- Check mobilenet_v3_large model
-- Check mobilenet_v3_small model
-- Check mnasnet0_5 model
-- Check mnasnet0_75 model
-- Check mnasnet1_0 model
-- Check mnasnet1_3 model
-- Check shufflenet_v2_x0_5 model
-- Check shufflenet_v2_x1_0 model
-- Check shufflenet_v2_x1_5 model
-- Check shufflenet_v2_x2_0 model
-- Check efficientnet_b0 model
grad mismatch for features.1.0.block.0.0.weight: -0.12117128819227219 vs -0.12117135524749756
-- Check efficientnet_b1 model
grad mismatch for features.0.0.weight: -0.008784545585513115 vs -0.008784502744674683
-- Check efficientnet_b2 model
grad mismatch for features.0.0.weight: -0.17733778059482574 vs -0.17733803391456604
-- Check efficientnet_b3 model
grad mismatch for features.0.0.weight: 0.056723516434431076 vs 0.05672360211610794
-- Check efficientnet_b4 model
grad mismatch for features.0.0.weight: -0.008806591853499413 vs -0.008806349709630013
-- Check efficientnet_b5 model
grad mismatch for features.0.0.weight: 0.18076033890247345 vs 0.1807602196931839
-- Check efficientnet_b6 model
grad mismatch for features.0.0.weight: 0.1274760216474533 vs 0.12747612595558167
-- Check efficientnet_b7 model
grad mismatch for features.0.0.weight: -0.030335871502757072 vs -0.030335767194628716
-- Check regnet_y_400mf model
-- Check regnet_y_800mf model
-- Check regnet_y_1_6gf model
-- Check regnet_y_3_2gf model
-- Check regnet_y_8gf model
-- Check regnet_y_16gf model
-- Check regnet_y_32gf model
-- Check regnet_y_128gf model
-- Check regnet_x_400mf model
-- Check regnet_x_800mf model
-- Check regnet_x_1_6gf model
-- Check regnet_x_3_2gf model
-- Check regnet_x_8gf model
-- Check regnet_x_16gf model
-- Check regnet_x_32gf model
-- Check vit_b_16 model
-- Check vit_b_32 model
-- Check vit_l_16 model
-- Check vit_l_32 model
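In the log above, only the efficientnet family fails the check, and the printed gradient means differ only around the sixth to seventh significant digit, which points to accumulated floating-point differences rather than genuinely wrong gradients. If one wanted the comparison to tolerate such noise, a looser relative tolerance could be used; the sketch below is an assumption about how the check might be relaxed, not something the gist does:

import torch
from torch.testing import assert_close

# Toy stand-ins for an autograd gradient and a functorch gradient that differ
# only by a small relative amount, mimicking the efficientnet mismatches above
autograd_grad = torch.full((1000,), 0.5)
functorch_grad = autograd_grad * (1 + 5e-5)

# The scripts' check (allclose with atol=1e-5 and the default rtol=1e-5) rejects this
print(autograd_grad.allclose(functorch_grad, atol=1e-5))  # False

# A larger rtol scales the allowance with the gradient magnitude and accepts it
assert_close(functorch_grad, autograd_grad, rtol=1e-4, atol=1e-5)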
Torch: 1.12.0.dev20220209+cu111
torchvision: 0.13.0.dev20220209+cu111
Functorch: 0.2.0a0+da6ec37

-- Check fcn_resnet50 model
-- Check fcn_resnet101 model
-- Check deeplabv3_resnet50 model
-- Check deeplabv3_resnet101 model
-- Check deeplabv3_mobilenet_v3_large model
-- Check lraspp_mobilenet_v3_large model