functorch make_functional + grad checks vs pytorch computed grads (NLP HF transformers)
Torch: 1.12.0.dev20220215+cu111
transformers: 4.16.2
Functorch: 0.2.0a0+c9d03e8

-- Check bert-base-cased model
-- Check gpt2 model
-- Check facebook/bart-large model
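The file above is the captured stdout; the script that produced it follows. As a lead-in, a minimal sketch of the pattern the script validates, with a toy nn.Linear model and random input standing in for the HF models (both are illustrative assumptions, not part of the gist):

import torch
import torch.nn as nn
from functorch import make_functional, grad

# Toy stand-in for the HF models checked below (illustrative assumption)
model = nn.Linear(4, 2)
x = torch.randn(3, 4)

# Split the module into a stateless callable plus its parameter tuple;
# the original `model` keeps its own parameters untouched
func_model, params = make_functional(model)

def loss_fn(params, x):
    return func_model(params, x).mean()

# functorch grads, taken w.r.t. the first argument (params)
param_grads = grad(loss_fn)(params, x)

# Reference grads via plain autograd on the original module
model(x).mean().backward()
for g, p in zip(param_grads, model.parameters()):
    assert p.grad.allclose(g, atol=1e-5)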
import torch
import torch.nn as nn

from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad

import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig

tested_models = [
    "bert-base-cased",
    "gpt2",
    "facebook/bart-large",
]
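
# Reference path: plain autograd. compute_grads leaves each parameter's
# gradient in its .grad attribute for the comparison against functorch.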
def compute_grads(model, inputs):
    # Fix seed to fix dropout
    torch.manual_seed(0)
    outputs = model(**inputs)
    out = outputs.last_hidden_state
    out.mean().backward()


device = "cpu"


def check_grads_model(model_name, device):
    torch.manual_seed(0)
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Build the model from config (random init) rather than pretrained weights
    model = AutoModel.from_config(config)
    model = model.to(device)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    # Pick the functional transform depending on whether the model carries
    # buffers (e.g. BERT's registered position_ids) in addition to parameters
    has_buffers = len(list(model.buffers())) > 0
    if has_buffers:
        func_model, weights, buffers = make_functional_with_buffers(model)
    else:
        func_model, weights = make_functional(model)
        buffers = None
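
    # Functional path: a pure loss of (weights, buffers, inputs). functorch's
    # grad transform differentiates w.r.t. the first positional argument, so
    # only `weights` gets gradients; buffers and inputs are treated as constants.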
    def compute_loss_ft(weights, buffers, inputs):
        # Fix seed to fix dropout
        torch.manual_seed(0)
        if buffers is None:
            output = func_model(weights, **inputs)
        else:
            output = func_model(weights, buffers, **inputs)
        return output.last_hidden_state.mean()

    compute_grad = grad(compute_loss_ft)
    w_grads = compute_grad(weights, buffers, inputs)

    # Populate .grad on the original model via plain autograd
    compute_grads(model, inputs)

    assert len(w_grads) == len(list(model.parameters()))
    for wg, (n, p) in zip(w_grads, model.named_parameters()):
        # Parameters not touched by this forward pass (e.g. an unused pooler)
        # have no autograd grad; skip them
        if p.grad is None:
            continue
        assert p.grad.allclose(wg, atol=1e-5), f"grad mismatch for {n}: {p.grad.mean()} vs {wg.mean()}"
        # print(p.grad.allclose(wg, atol=1e-5), p.grad.shape, wg.shape, f"grads for {n}: {p.grad.mean()} vs {wg.mean()}")
print("") | |
print("Torch:", torch.__version__) | |
print("transformers:", transformers.__version__) | |
print("Functorch:", ft_version) | |
print("") | |
for model_name in tested_models: | |
print(f"-- Check {model_name} model") | |
try: | |
check_grads_model(model_name, device=device) | |
except AssertionError as e: | |
print(e) |
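
Run as a plain script, this prints the version banner and the three "-- Check ..." lines shown at the top. A gradient mismatch beyond atol=1e-5 would surface as an assertion message printed under the corresponding model line rather than stopping the run.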