Created April 2, 2020 21:18
-
-
Save shashankprasanna/245f5d3fe2116f62d2432071a75c285f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from smdebug.rules import Rule | |
class CustomGradientRule(Rule):
    """smdebug rule that fires when a gradient's mean absolute value is too large.

    At each evaluated step the rule inspects every tensor in the trial's
    ``"gradients"`` collection and reports ``True`` as soon as one of them has
    an absolute-value mean greater than ``threshold``.
    """

    def __init__(self, base_trial, threshold=10.0):
        """Attach the rule to ``base_trial`` and record the threshold.

        Args:
            base_trial: the trial whose saved tensors this rule inspects.
            threshold: upper bound on a gradient's mean absolute value. It is
                passed through ``float()`` (presumably because rule parameters
                may arrive as strings — confirm against the rule launcher).
        """
        super().__init__(base_trial)
        self.threshold = float(threshold)

    def invoke_at_step(self, step):
        """Report whether any gradient at ``step`` exceeds the threshold.

        Tensors are visited in the order ``tensor_names`` yields them, and
        evaluation stops at the first one whose mean absolute value is greater
        than ``self.threshold``.

        Args:
            step: the step number to evaluate.

        Returns:
            True if some gradient tensor is over the limit, False otherwise
            (including when the ``"gradients"`` collection is empty).
        """
        trial = self.base_trial
        limit = self.threshold
        # Lazily computed per-tensor reductions, so any() can short-circuit
        # and skip reducing the remaining tensors once one is over the limit.
        per_tensor_abs_mean = (
            trial.tensor(name).reduction_value(step, "mean", abs=True)
            for name in trial.tensor_names(collection="gradients")
        )
        return any(value > limit for value in per_tensor_abs_mean)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.