Skip to content

Instantly share code, notes, and snippets.

@vfdev-5
Created February 16, 2022 18:17
Show Gist options
  • Save vfdev-5/62eab9500daa2cf29554a6128edbaa12 to your computer and use it in GitHub Desktop.
Save vfdev-5/62eab9500daa2cf29554a6128edbaa12 to your computer and use it in GitHub Desktop.
functorch make_functional + grad checks vs. PyTorch-computed grads (NLP, Hugging Face transformers)
Torch: 1.12.0.dev20220215+cu111
transformers: 4.16.2
Functorch: 0.2.0a0+c9d03e8
-- Check bert-base-cased model
-- Check gpt2 model
-- Check facebook/bart-large model
import torch
import torch.nn as nn
from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig
# Hugging Face Hub model identifiers to run the gradient comparison on, in order.
tested_models = [
    "bert-base-cased",
    "gpt2",
    "facebook/bart-large",
]
def compute_grads(model, inpts):
    """Populate ``.grad`` on ``model``'s parameters using standard PyTorch autograd.

    Runs one forward pass on ``inpts`` and backpropagates the mean of
    ``outputs.last_hidden_state``. This is the reference loss that the
    functorch-computed gradients are compared against.

    Args:
        model: module whose forward output exposes a ``last_hidden_state``
            tensor attribute.
        inpts: mapping of keyword arguments forwarded to ``model``.
    """
    # Clear gradients left over from any earlier backward pass; otherwise
    # ``backward()`` accumulates into ``.grad`` and the comparison is wrong.
    model.zero_grad()
    # Fix seed to fix dropout, so the random state matches the functorch
    # forward pass, which is seeded with the same value.
    torch.manual_seed(0)
    outputs = model(**inpts)
    out = outputs.last_hidden_state
    out.mean().backward()
# Device passed to check_grads_model for every tested model.
device = "cpu"
def check_grads_model(model_name, device, atol=1e-5):
    """Compare functorch ``grad`` against PyTorch autograd for one HF model.

    Builds a model from ``model_name``'s config via ``AutoModel.from_config``
    (weights come from the seeded initialisation, not a checkpoint). It then
    converts the model with ``make_functional`` (or
    ``make_functional_with_buffers`` when the model has buffers). Finally it
    checks that the per-parameter gradients of ``last_hidden_state.mean()``
    from functorch match those that ``compute_grads`` leaves in ``.grad``.

    Args:
        model_name: Hugging Face Hub identifier given to ``AutoConfig`` and
            ``AutoTokenizer``.
        device: device the model and the tokenized inputs are moved to.
        atol: absolute tolerance passed to ``Tensor.allclose``.

    Raises:
        AssertionError: if the number of gradients differs from the number of
            parameters, or any parameter's gradients are not close.
    """
    torch.manual_seed(0)
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModel.from_config(config)
    model = model.to(device)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    # The model was moved to ``device`` above; move the input tensors too so
    # the check also works on non-CPU devices.
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Models that register buffers need the buffer-aware variant so the
    # functional call receives the buffers explicitly.
    has_buffers = len(list(model.buffers())) > 0
    if has_buffers:
        func_model, weights, buffers = make_functional_with_buffers(model)
    else:
        func_model, weights = make_functional(model)
        buffers = None

    def compute_loss_ft(weights, buffers, inputs):
        # Fix seed to fix dropout; must match the seed used by compute_grads.
        torch.manual_seed(0)
        if buffers is None:
            outputs = func_model(weights, **inputs)
        else:
            outputs = func_model(weights, buffers, **inputs)
        return outputs.last_hidden_state.mean()

    # grad() differentiates with respect to the first argument (weights) only.
    w_grads = grad(compute_loss_ft)(weights, buffers, inputs)
    compute_grads(model, inputs)

    named_params = list(model.named_parameters())
    assert len(w_grads) == len(named_params)
    for wg, (n, p) in zip(w_grads, named_params):
        # Parameters that received no gradient from backward() are skipped.
        if p.grad is None:
            continue
        assert p.grad.allclose(wg, atol=atol), (
            f"grad mismatch for {n}: {p.grad.mean()} vs {wg.mean()}"
        )
# Report the library versions the comparison was run against.
print("")
print("Torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("Functorch:", ft_version)
print("")

# Check each model independently: an AssertionError (gradient mismatch) is
# printed and the loop continues with the next model instead of aborting.
for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as e:
        print(e)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment