Skip to content

Instantly share code, notes, and snippets.

@shashankprasanna
Created April 2, 2020 21:18
Show Gist options
  • Save shashankprasanna/245f5d3fe2116f62d2432071a75c285f to your computer and use it in GitHub Desktop.
Save shashankprasanna/245f5d3fe2116f62d2432071a75c285f to your computer and use it in GitHub Desktop.
from smdebug.exceptions import TensorUnavailableForStep
from smdebug.rules import Rule
class CustomGradientRule(Rule):
    """smdebug rule that fires when any gradient's absolute mean exceeds a threshold.

    At each evaluated step, every tensor in the base trial's ``"gradients"``
    collection is reduced to the mean of its absolute values. That value is
    compared against ``threshold``.
    """

    def __init__(self, base_trial, threshold=10.0):
        """Create the rule.

        Args:
            base_trial: The smdebug trial whose saved tensors are inspected.
            threshold: Absolute-mean value above which a gradient triggers the
                rule. It is coerced with ``float()``, so numeric strings are
                accepted too (presumably because rule parameters may arrive as
                strings; confirm against the caller).
        """
        super().__init__(base_trial)
        self.threshold = float(threshold)

    def invoke_at_step(self, step):
        """Check every gradient tensor at ``step``.

        Args:
            step: The training step to evaluate.

        Returns:
            True as soon as one gradient's absolute mean is strictly greater
            than ``self.threshold``. Otherwise False, including when the
            gradients collection is empty or no gradient is available at
            ``step``.
        """
        for tname in self.base_trial.tensor_names(collection="gradients"):
            tensor = self.base_trial.tensor(tname)
            try:
                abs_mean = tensor.reduction_value(step, "mean", abs=True)
            except TensorUnavailableForStep:
                # This tensor, or its abs-mean reduction, was not saved at
                # `step`. Skip only this tensor. Letting the exception escape
                # would abandon the whole step and leave the remaining
                # gradients unchecked.
                continue
            # NOTE: a NaN mean compares False here, so NaN gradients alone do
            # not trigger this rule.
            if abs_mean > self.threshold:
                return True
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment