Last active
June 14, 2021 17:19
-
-
Save Mason-McGough/5e19464613275ca7ea111c389c946011 to your computer and use it in GitHub Desktop.
Simple gradient tracker for PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GradTracker:
    """
    Simple gradient tracker for a PyTorch model.

    For every parameter that requires grad, ``update()`` records the mean of
    its ``.grad`` tensor into a per-epoch buffer. ``end_epoch()`` collapses
    that buffer to its RMS value and appends it to the history.

    Attributes:
        model: The module whose ``named_parameters()`` are tracked.
        grad_mags: ``{param_name: [rms_of_epoch_0, rms_of_epoch_1, ...]}``.
        grad_buffer: ``{param_name: [mean_grad_of_step_0, ...]}`` for the
            epoch currently in progress.
    """

    def __init__(
        self,
        # String annotation: the class can be defined without `nn` in scope.
        model: "nn.Module",
    ):
        self.model = model
        self.reset_history()
        self.reset_buffer()

    def __getitem__(
        self,
        key: str,
    ) -> list:
        """Return the per-epoch RMS history (a list of floats) for ``key``."""
        return self.grad_mags[key]

    def __str__(self) -> str:
        return str(self.grad_mags)

    def _empty_record(self) -> dict:
        """Return ``{name: []}`` for every parameter that currently requires grad."""
        return {
            name: []
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

    def reset_history(self):
        """Discard all recorded per-epoch magnitudes."""
        self.grad_mags = self._empty_record()

    def reset_buffer(self):
        """Discard the per-step values collected for the current epoch."""
        self.grad_buffer = self._empty_record()

    def update(self):
        """
        Record the mean gradient of each trainable parameter for this step.

        Intended to be called after ``backward()``. Parameters whose ``.grad``
        is ``None`` (no backward pass yet, or unused in the forward pass) are
        skipped rather than raising.
        """
        for name, param in self.model.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue
            # NOTE: this is the *signed* mean, so positive and negative
            # components can cancel; it is not a gradient norm.
            # setdefault covers parameters unfrozen after the last reset.
            self.grad_buffer.setdefault(name, []).append(param.grad.mean().item())

    def end_epoch(self):
        """
        Append one RMS value per buffered parameter to ``grad_mags``, then
        start a fresh buffer.

        A parameter with no recorded steps this epoch gets ``nan``. A
        parameter that only became trainable mid-training starts its history
        at that point, so its list may be shorter than the others.
        """
        for name, vals in self.grad_buffer.items():
            self.grad_mags.setdefault(name, []).append(self.compute_magnitude(vals))
        self.reset_buffer()

    @staticmethod
    def compute_magnitude(val: list) -> float:
        """
        Calculate the RMS magnitude of a list of numbers.

        Returns ``nan`` for an empty list, where the RMS is undefined.
        """
        if not val:
            return float("nan")
        return (sum(x ** 2 for x in val) / len(val)) ** 0.5
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment