| Class | PyTorch | MXNet Gluon |
|---|---|---|
| Dataset holding arrays | `torch.utils.data.TensorDataset(data_tensor, label_tensor)` | `gluon.data.ArrayDataset(data_array, label_array)` |
| Data loader | `torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, drop_last=False)` | `gluon.data.DataLoader(dataset, batch_size=None, shuffle=False, sampler=None, last_batch='keep', batch_sampler=None, batchify_fn=None, num_workers=0)`; `last_batch` also accepts `'discard'` or `'rollover'` |
| Sequential sampler | `torch.utils.data.sampler.SequentialSampler(data_source)` | `gluon.data.SequentialSampler(length)` |
| Random-order sampler | `torch.utils.data.sampler.RandomSampler(data_source)` | `gluon.data.RandomSampler(length)` |
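The sampler and `last_batch` options above are easier to compare with a concrete model. The sketch below is plain Python (no PyTorch or MXNet needed) showing how a sampler's index order is cut into batches under Gluon's three `last_batch` policies. All function names (`sequential_indices`, `random_indices`, `make_batches`, `gluon_last_batch_from_drop_last`) are hypothetical; this is an illustration of the documented behaviour, not either library's implementation. PyTorch's `drop_last=False`/`True` roughly corresponds to `'keep'`/`'discard'`; `'rollover'`, which carries leftover samples into the next epoch, has no direct PyTorch `DataLoader` flag.

```python
import random
from typing import List, Optional, Sequence, Tuple

# Hypothetical helpers: a plain-Python model of sampler + batching behaviour,
# not the PyTorch or MXNet implementation.


def sequential_indices(length: int) -> List[int]:
    """Indices 0..length-1 in order, as a sequential sampler yields them."""
    return list(range(length))


def random_indices(length: int, seed: Optional[int] = None) -> List[int]:
    """A random permutation of 0..length-1, as a random sampler yields them."""
    indices = list(range(length))
    random.Random(seed).shuffle(indices)
    return indices


def make_batches(indices: Sequence[int], batch_size: int, last_batch: str = "keep",
                 carry: Sequence[int] = ()) -> Tuple[List[List[int]], List[int]]:
    """Cut one epoch of indices into batches.

    Returns (batches, leftover). `carry` is the leftover from the previous
    epoch and is only expected to be non-empty when last_batch == "rollover".
    """
    if last_batch not in ("keep", "discard", "rollover"):
        raise ValueError(f"unknown last_batch policy: {last_batch!r}")
    pool = list(carry) + list(indices)          # rolled-over samples come first
    n_full = len(pool) // batch_size * batch_size
    batches = [pool[i:i + batch_size] for i in range(0, n_full, batch_size)]
    leftover = pool[n_full:]
    if last_batch == "keep" and leftover:
        batches.append(leftover)                # short final batch is returned
        leftover = []
    elif last_batch == "discard":
        leftover = []                           # incomplete batch is dropped
    return batches, leftover                    # "rollover": leftover feeds next epoch


def gluon_last_batch_from_drop_last(drop_last: bool) -> str:
    """Map PyTorch's drop_last flag to the closest Gluon last_batch value."""
    return "discard" if drop_last else "keep"


# Two epochs of 10 samples with batch_size=4 under "rollover":
batches, leftover = make_batches(sequential_indices(10), 4, "rollover")
# batches == [[0, 1, 2, 3], [4, 5, 6, 7]], leftover == [8, 9]
batches, leftover = make_batches(sequential_indices(10), 4, "rollover", carry=leftover)
# batches == [[8, 9, 0, 1], [2, 3, 4, 5], [6, 7, 8, 9]], leftover == []
```

Shuffling lives entirely in the sampler here, and batching only slices whatever order the sampler produced, which roughly mirrors how both libraries accept a `sampler` separately from the batching options.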

PyTorch to Gluon cheatsheet
GitHub gist zhreshold/9e567ebd8f756eac9a3bd5c87e3024e1, last active November 26, 2017 05:04.

| Operation | PyTorch | MXNet Gluon |
|---|---|---|
| Save model parameters | `torch.save(the_model.state_dict(), filename)` | `model.save_params(filename)` |
| Load model parameters | `the_model.load_state_dict(torch.load(filename))` | `model.load_params(filename, ctx, allow_missing=False, ignore_extra=False)` |
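Gluon's `allow_missing` and `ignore_extra` flags control how strictly the parameter names stored in the file must match the block's own parameters. The plain-Python sketch below models that name check; `select_params_to_load` is a hypothetical helper, not MXNet code, and it raises `KeyError` where the real library may raise a different exception.

```python
from typing import Iterable, Set

# Hypothetical helper: a plain-Python model of the name check implied by
# Gluon's load_params(..., allow_missing=False, ignore_extra=False).
# Not MXNet code; the exception type (KeyError) is this sketch's choice.


def select_params_to_load(block_names: Iterable[str], file_names: Iterable[str],
                          allow_missing: bool = False,
                          ignore_extra: bool = False) -> Set[str]:
    """Return the parameter names that would be loaded from the file."""
    in_block = set(block_names)
    in_file = set(file_names)
    missing = in_block - in_file  # the model expects these, the file lacks them
    extra = in_file - in_block    # the file has these, the model does not
    if missing and not allow_missing:
        raise KeyError(f"parameters missing from file: {sorted(missing)}")
    if extra and not ignore_extra:
        raise KeyError(f"file contains unknown parameters: {sorted(extra)}")
    return in_block & in_file


# Loading a checkpoint saved from a smaller model into a larger one:
loaded = select_params_to_load(["conv0_weight", "dense0_weight"], ["conv0_weight"],
                               allow_missing=True)
# loaded == {"conv0_weight"}; dense0_weight is simply not loaded
```

With both flags left at `False`, any mismatch in either direction is an error, which is the safe default when reloading parameters into the same architecture that saved them.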