@shashankprasanna
Created April 2, 2020 21:12
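import torch.nn as nn
import torch.optim as optim
import smdebug.pytorch as smd

# get_network() is a user-defined helper (not shown) that returns a torch.nn.Module
net = get_network()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Configure the smdebug hook: write tensors under ./smd_outputs/<job_name>,
# save every 10 steps, and keep only the 'gradients' and 'biases' collections
job_name = 'pytorch-debug-job'
hook = smd.Hook(out_dir=f'./smd_outputs/{job_name}',
                save_config=smd.SaveConfig(save_interval=10),
                include_collections=['gradients', 'biases'])

# Attach the hook to the model and to the loss module so their tensors are captured
hook.register_module(net)
hook.register_loss(criterion)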
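To make `save_interval=10` concrete, here is a small pure-Python sketch of which steps would be written to disk. The helper `saved_steps` is hypothetical and not part of smdebug. The rule it encodes is an assumption about `SaveConfig`'s behavior: a step `s` is saved when `s >= start_step`, `s < end_step` (if set), and `s - start_step` is a multiple of `save_interval`, with `start_step` defaulting to 0.

```python
from typing import List, Optional


def saved_steps(num_steps: int, save_interval: int,
                start_step: int = 0, end_step: Optional[int] = None) -> List[int]:
    """Hypothetical helper: list the 0-based steps that would be saved.

    Assumed rule (not taken from smdebug source): save step s when
    start_step <= s < end_step and (s - start_step) % save_interval == 0.
    """
    if save_interval <= 0:
        raise ValueError("save_interval must be positive")
    # Never go past the last executed step, nor past end_step if one is given
    stop = num_steps if end_step is None else min(num_steps, end_step)
    # Stepping by save_interval from start_step yields exactly the multiples
    return list(range(start_step, stop, save_interval))


# With save_interval=10, as in the hook above, a 35-step run saves 4 steps.
print(saved_steps(35, 10))  # [0, 10, 20, 30]
```

Under this assumption, the number of saved steps grows as roughly `num_steps / save_interval`, so raising the interval is the simplest way to cut the volume of tensors written to `out_dir`.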