from graphviz import Digraph
import torch
from torch.autograd import Variable, Function


def iter_graph(root, callback):
    """Walk the autograd graph from `root`, calling `callback` on each node once."""
    queue = [root]
    seen = set()
    while queue:
        fn = queue.pop()
        if fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                queue.append(next_fn)
        callback(fn)


def register_hooks(var):
    fn_dict = {}

    # Record the gradients produced by every backward node as backward() runs.
    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            fn_dict[fn] = grad_input
        fn.register_hook(register_grad)
    iter_graph(var.grad_fn, hook_cb)

    def is_bad_grad(grad_output):
        # "Bad" means NaN (x != x only for NaN) or suspiciously large.
        grad_output = grad_output.data
        return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()

    def make_dot():
        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

        def size_to_str(size):
            return '(' + ', '.join(map(str, size)) + ')'

        def build_graph(fn):
            if hasattr(fn, 'variable'):  # AccumulateGrad nodes hold a leaf Variable
                u = fn.variable
                node_name = 'Variable\n ' + size_to_str(u.size())
                dot.node(str(id(u)), node_name, fillcolor='lightblue')
            else:
                assert fn in fn_dict, fn
                fillcolor = 'white'
                if any(is_bad_grad(gi) for gi in fn_dict[fn]):
                    fillcolor = 'red'  # highlight nodes that produced bad gradients
                dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor)
            for next_fn, _ in fn.next_functions:
                if next_fn is not None:
                    next_id = id(getattr(next_fn, 'variable', next_fn))
                    dot.edge(str(next_id), str(id(fn)))
        iter_graph(var.grad_fn, build_graph)

        return dot

    return make_dot


if __name__ == '__main__':
    x = Variable(torch.randn(10, 10), requires_grad=True)
    y = Variable(torch.randn(10, 10), requires_grad=True)
    z = x / (y * 0)  # division by zero produces inf/nan gradients
    z = z.sum() * 2
    get_dot = register_hooks(z)  # must be called before backward()
    z.backward()
    dot = get_dot()
    dot.save('tmp.dot')
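The script only writes the DOT source; to actually look at the graph you can render it with the same graphviz package (assuming the Graphviz system binaries are installed), e.g.:

# Render the saved graph to an image; this needs the Graphviz binaries
# (the `dot` executable), not just the Python package.
dot.format = 'png'
dot.render('tmp')   # writes 'tmp' (DOT source) and 'tmp.png'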
You should probably change is_bad_grad to:

def is_bad_grad(grad_output):
    if grad_output is None:
        return True
    grad_output = grad_output.data
    return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()

Recent PyTorch sometimes passes None instead of a zero tensor to speed up gradient accumulation. Return True or False here depending on whether you consider a zero gradient good or bad.
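A minimal sketch of that behavior (my own repro, not from the gist; the exact placement of None may vary by version): a backward node can hand its hook None for an input that does not require grad.

import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)                           # requires_grad=False
z = (x * y).sum()

mul_node = z.grad_fn.next_functions[0][0]    # the MulBackward0 node

def show_grads(grad_inputs, grad_outputs):
    # Expect something like [torch.Size([3]), None]: x gets a gradient, y does not.
    print([g if g is None else g.shape for g in grad_inputs])

mul_node.register_hook(show_grads)
z.backward()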
I run out of RAM using this code, and if I instead run it only once every n iterations, it crashes anyway.
However, I got an interesting graph before the RAM ran out. Does anyone have a suggestion on how to approach the problem? The graph is very, very large, and there are red nodes everywhere, right through to the end of it.
My model is not very complicated (apart from the transformer itself):
# ... def __init__
self.bert = transformers.BertModel.from_pretrained(config.BASE_MODEL_PATH, return_dict=False)
self.bert_drop_1 = nn.Dropout(0.3)
self.out_labels = nn.Linear(768, 1)
self.sigmoid = nn.Sigmoid()

# ... def forward
o1, o2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
dropout = self.bert_drop_1(o2)
logits = self.out_labels(dropout)
labels = self.sigmoid(logits)
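For what it's worth, wiring the gist into a model like that just means hooking the final scalar before calling backward(); a sketch where model, loss_fn, and the batch tensors are hypothetical names, not from the post above:

# Hypothetical names: `model` wraps the layers above; `ids`, `mask`,
# `token_type_ids`, `targets`, and `loss_fn` stand in for the real pipeline.
labels = model(ids, mask, token_type_ids)   # sigmoid outputs from forward()
loss = loss_fn(labels, targets)             # any scalar loss works
get_dot = register_hooks(loss)              # register before calling backward()
loss.backward()
get_dot().save('bert_debug.dot')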
This fails on the first line of is_bad_grad using torch==1.9.1+cu111; any suggestions on what to change?
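Most likely (an assumption, in line with the None-gradient comment above, not verified on that exact build) grad_output is None there, so calling .data on it raises AttributeError. The guard suggested earlier should get past it; a sketch of the same fix, using torch.isnan for the NaN test:

import torch

def is_bad_grad(grad_output):
    if grad_output is None:   # newer PyTorch may pass None instead of a zero tensor
        return True           # or False, if you treat a zero gradient as fine
    return torch.isnan(grad_output).any() or grad_output.gt(1e6).any()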