Skip to content

Instantly share code, notes, and snippets.

@Mason-McGough
Last active June 14, 2021 17:19
Show Gist options
  • Save Mason-McGough/5e19464613275ca7ea111c389c946011 to your computer and use it in GitHub Desktop.
Simple gradient tracker for PyTorch
class GradTracker:
    """Track the per-epoch RMS magnitude of parameter gradients in a model.

    Usage: call ``update()`` after each backward pass to buffer the mean
    gradient of every trainable parameter, then ``end_epoch()`` once per
    epoch to fold the buffered samples into the running history
    (``grad_mags``, a dict mapping parameter name -> list of per-epoch
    RMS magnitudes).
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.reset_history()
        self.reset_buffer()

    def __getitem__(self, key: str) -> list:
        # NOTE: returns the full per-epoch history list for `key`
        # (the original annotation said `float`, but grad_mags values
        # are lists of floats).
        return self.grad_mags[key]

    def __str__(self) -> str:
        return str(self.grad_mags)

    def reset_history(self):
        """Clear the per-epoch magnitude history for all trainable params."""
        self.grad_mags = {
            name: []
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

    def reset_buffer(self):
        """Clear the within-epoch buffer of mean-gradient samples."""
        self.grad_buffer = {
            name: []
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

    def update(self):
        """Buffer the mean gradient of each trainable parameter.

        Call after ``loss.backward()``. Parameters whose ``.grad`` is
        still ``None`` (no backward pass yet, or unused in the current
        forward) are skipped instead of raising a TypeError.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_buffer[name].append(torch.mean(param.grad).item())

    def end_epoch(self):
        """Fold buffered samples into the history, then reset the buffer."""
        for name, val in self.grad_buffer.items():
            self.grad_mags[name].append(self.compute_magnitude(val))
        self.reset_buffer()

    @staticmethod
    def compute_magnitude(val: list) -> float:
        """Return the root-mean-square of `val`; 0.0 for an empty list.

        The empty-list guard prevents a ZeroDivisionError when an epoch
        ends without any ``update()`` samples for a parameter.
        """
        if not val:
            return 0.0
        return (sum(x ** 2 for x in val) / len(val)) ** 0.5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment