If nothing is specified, all argument combinations should be considered (the call patterns are sketched right after this list):
- copy_ no_sparse && no_quantize && self!=source && not_copy_transpose
- gather
- gather(out=)
- scatter_(Tensor)
- scatter(Tensor)
- scatter_(value)
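
For reference, a minimal sketch of each call pattern above (shapes and index values are my own choice):

import torch

src = torch.randn(3, 4)
dst = torch.zeros(3, 4)
idx = torch.randint(0, 3, (3, 4))

dst.copy_(src)                       # copy_: dense, non-quantized, self != source
g = torch.gather(src, 0, idx)        # gather
torch.gather(src, 0, idx, out=dst)   # gather(out=)
dst.scatter_(0, idx, src)            # scatter_(Tensor)
s = torch.scatter(dst, 0, idx, src)  # scatter(Tensor), out-of-place variant
dst.scatter_(0, idx, 1.0)            # scatter_(value)
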
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
import time
import warnings
import weakref
import traceback
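
These imports are the usual toolkit for a dispatch-mode tensor tracker. A minimal sketch of the kind of tool they combine into, assuming the goal is to record where live tensors were created (the mode and its names are mine, not necessarily the original script's):

class CreationTracker(TorchDispatchMode):
    def __init__(self):
        # Weak keys: an entry vanishes automatically when its tensor dies.
        self.created_at = WeakTensorKeyDictionary()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        stack = "".join(traceback.format_stack())

        def record(t):
            self.created_at[t] = stack
            return t

        tree_map_only(torch.Tensor, record, out)
        return out

# Usage: every tensor created by an op under the mode is tracked.
with CreationTracker() as tracker:
    a = torch.randn(3)
    b = a * 2
print(len(tracker.created_at))  # 2 live tensors created under the mode
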
import torch
from torch.autograd import Function
import torch.utils._pytree as pytree

# Basically wraps things in and out before passing them to the real
# function that the user defined.
def pytreeify(cls):
    assert issubclass(cls, Function)

    orig_fw = cls.forward
    orig_bw = cls.backward
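
The decorator body is cut off above. A sketch of how the rest of the idea can go (names ending in _sketch are hypothetical, and I assume every pytree leaf is a tensor): flatten the pytree inputs before apply, run the user's forward on the unflattened structures, and hand autograd only flat tuples of tensors.

def pytreeify_sketch(cls):
    assert issubclass(cls, Function)
    orig_fw, orig_bw, orig_apply = cls.forward, cls.backward, cls.apply

    def new_apply(*inp):
        flat_inp, struct = pytree.tree_flatten(inp)
        out_struct_holder = []
        flat_out = orig_apply(struct, out_struct_holder, *flat_inp)
        return pytree.tree_unflatten(list(flat_out), out_struct_holder[0])

    def new_forward(ctx, struct, out_struct_holder, *flat_inp):
        inp = pytree.tree_unflatten(list(flat_inp), struct)
        out = orig_fw(ctx, *inp)
        flat_out, out_struct = pytree.tree_flatten(out)
        ctx._out_struct = out_struct
        out_struct_holder.append(out_struct)
        return tuple(flat_out)

    def new_backward(ctx, *flat_grad_out):
        grad_out = pytree.tree_unflatten(list(flat_grad_out), ctx._out_struct)
        if not isinstance(grad_out, tuple):
            grad_out = (grad_out,)
        grad_inp = orig_bw(ctx, *grad_out)
        flat_grad_inp, _ = pytree.tree_flatten(grad_inp)
        # Two leading Nones: grads for the extra struct/holder arguments.
        return (None, None) + tuple(flat_grad_inp)

    cls.apply = new_apply
    cls.forward = new_forward
    cls.backward = new_backward
    return cls

# Usage: the Function sees dicts; autograd only ever sees flat tensors.
@pytreeify_sketch
class ScaleDict(Function):
    @staticmethod
    def forward(ctx, d):
        return {"out": d["x"] * 2}

    @staticmethod
    def backward(ctx, g):  # g has the same dict structure as the output
        return ({"x": g["out"] * 2},)

x = torch.randn(3, requires_grad=True)
ScaleDict.apply({"x": x})["out"].sum().backward()  # x.grad is all 2s
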
# Implements Alban's idea of making the forward traceback that corresponds
# to the execution of the current backward node available as a global.
# Updated version of https://gist.github.com/soulitzer/28140cc4cd7d26828ff7f07b1235d9f5
# to add inter-op tracking.
import torch
from torch import autograd
from torch.utils._python_dispatch import TorchDispatchMode

current_metadata = None
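
A sketch of the mechanism, with one simplification that is mine: it hooks at the TorchFunctionMode level (above autograd, where grad_fn is already populated) rather than TorchDispatchMode, and the class name is my own.

import traceback
from torch.overrides import TorchFunctionMode

class TrackForwardStacks(TorchFunctionMode):
    # Record, for every op output, the Python stack at forward time and
    # publish it into `current_metadata` when that node runs in backward.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        stack = "".join(traceback.format_stack())
        outs = out if isinstance(out, tuple) else (out,)
        for t in outs:
            if isinstance(t, torch.Tensor) and t.grad_fn is not None:
                def prehook(grad_outputs, fwd_stack=stack):
                    global current_metadata
                    current_metadata = fwd_stack
                t.grad_fn.register_prehook(prehook)
        return out

with TrackForwardStacks():
    y = (torch.randn(2, requires_grad=True) * 3).sum()
y.backward()
print(current_metadata)  # forward stack of the last node that ran in backward
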
import torch
from torch import nn
from torch.optim.sgd import sgd
import gc
import objgraph
import weakref

def all():
    # Only a subset of the args you could have
    def set_sgd_hook(mod, p, lr, weight_decay, momentum):
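
A sketch of what set_sgd_hook can do, using the newer Tensor.register_post_accumulate_grad_hook (PyTorch >= 2.1), which may differ from the original gist's mechanism: apply the SGD step as soon as a parameter's gradient has been accumulated, then free the gradient before the rest of backward finishes.

def set_sgd_hook(mod, p, lr, weight_decay, momentum):
    # mod is unused in this minimal sketch; kept for signature fidelity.
    buf = torch.zeros_like(p)  # momentum buffer for this parameter

    def hook(param):
        with torch.no_grad():
            sgd([param], [param.grad], [buf],
                weight_decay=weight_decay, momentum=momentum, lr=lr,
                dampening=0.0, nesterov=False, maximize=False)
        param.grad = None  # drop the gradient right away to save memory

    return p.register_post_accumulate_grad_hook(hook)

# Usage: fuse the optimizer step into backward for every parameter.
mod = nn.Linear(4, 4)
handles = [set_sgd_hook(mod, p, lr=0.1, weight_decay=0.0, momentum=0.9)
           for p in mod.parameters()]
mod(torch.randn(2, 4)).sum().backward()  # params are updated during backward
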
from patch_convolution import *
import torch
import torch.nn as nn
import time

# ---------------
# Parameters
# ---------------
# Number of profile iterations to run
itt = 30
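
The rest of the script presumably times the patch-based convolution against a regular one. A sketch of the usual CUDA timing pattern (the layer and shapes are my own stand-ins, since patch_convolution's API is not shown here):

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1).cuda()
x = torch.randn(32, 64, 56, 56, device="cuda")

# Warmup iterations so cuDNN autotuning and lazy init don't skew timings.
for _ in range(5):
    conv(x)
torch.cuda.synchronize()  # CUDA is async: sync before and after timing

start = time.time()
for _ in range(itt):
    conv(x)
torch.cuda.synchronize()
print(f"avg forward: {(time.time() - start) / itt * 1e3:.3f} ms")
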
import torch
from torch import nn
from torch.nn import functional as F

class EasyDataParallel(nn.Module):
    def __init__(self, gpus):
        super().__init__()
        # TODO: handle the CPU / single-GPU case better
        assert isinstance(gpus, list)
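
A sketch of how the rest of the class can go, built on the same torch.nn.parallel helpers that nn.DataParallel uses (the module argument and the method bodies are my guesses at the truncated code):

from torch.nn.parallel import replicate, scatter, parallel_apply, gather

class EasyDataParallelSketch(nn.Module):
    def __init__(self, module, gpus):
        super().__init__()
        assert isinstance(gpus, list)
        self.gpus = gpus
        self.module = module.cuda(gpus[0])

    def forward(self, x):
        if len(self.gpus) == 1:
            return self.module(x.cuda(self.gpus[0]))
        inputs = scatter(x, self.gpus)            # split the batch across devices
        replicas = replicate(self.module, self.gpus[:len(inputs)])
        outputs = parallel_apply(replicas, [(i,) for i in inputs])
        return gather(outputs, self.gpus[0])      # concatenate on the first GPU
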
import torch
from torch import nn
from torchviz import make_dot
from torch.autograd.gradcheck import gradcheck

# Double precision keeps gradcheck's finite-difference comparisons accurate.
torch.set_default_tensor_type(torch.DoubleTensor)

my_mod = nn.Sequential(nn.Linear(2, 2, bias=False), nn.Sigmoid(), nn.Linear(2, 2, bias=False), nn.Sigmoid(), nn.Linear(2, 1, bias=False))
params = list(my_mod.parameters())
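
A sketch of how such a script usually continues: render the autograd graph with torchviz, then check analytical gradients against finite differences (the input shape is my own choice):

x = torch.randn(5, 2, requires_grad=True)
out = my_mod(x)

# Render the autograd graph of the output to a graphviz file.
make_dot(out, params=dict(my_mod.named_parameters())).render("my_mod_graph")

# gradcheck runs in double precision thanks to the default tensor type above.
print(gradcheck(my_mod, (x,)))
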
local threads = require "threads"
threads.Threads.serialization('threads.sharedserialize')

n_task = 3
local pools = {}
for task=1,n_task do
   pools[task] = threads.Threads(5,
      function()
         -- Needed only for serialized elements
         require 'torch'
      end
   )
end