pt.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class Trainer(object):
    def __init__(self, net, optimizer, ngpus=1):
        self.net = net
        self.optimizer = optimizer
        self.ngpus = ngpus
        self.replicas = None  # used by the "replicate once" alternative described in step()
    def step(self, batch, train, *args, **kwargs):
        """
        Alternatively, you can avoid replicating your model on each batch.
        This is particularly useful if your model has any kind of state.

            if self.replicas is None:
                self.replicas = nn.parallel.replicate(self.net, device_ids)
            args, kwargs = torch.nn.parallel.scatter_gather.scatter_kwargs(args, kwargs, device_ids, 0)
            outputs = torch.nn.parallel.parallel_apply(self.replicas, args, kwargs, device_ids)
            out = torch.nn.parallel.scatter_gather.gather(outputs, 0, 0)

        Note that you might need to scale the number of workers for your
        DataLoader up with the number of GPUs.
        """
        if self.ngpus > 1:
            # data_parallel scatters the inputs along dim 0 across the GPUs,
            # runs a replica on each, and gathers the outputs back along
            # dim 0. So if each replica's output has shape (A, B, C, ...),
            # the gathered output has shape (N*A, B, C, ...), where N is
            # the number of GPUs.
            device_ids = list(range(self.ngpus))
            out = torch.nn.parallel.data_parallel(
                self.net, args, device_ids,
                module_kwargs=kwargs)
        else:
            out = self.net(*args, **kwargs)
        # compute_loss is not defined in this gist; assume a subclass provides it.
        loss = self.compute_loss(batch, out)
        if train:
            self.update(loss)
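    def step_replicated(self, batch, train, *args, **kwargs):
        # Hedged sketch (the method name is mine, not from the original gist)
        # of the "replicate once" alternative described in the step()
        # docstring: replicate the module a single time, then scatter inputs,
        # apply the replicas in parallel, and gather the outputs each step.
        device_ids = list(range(self.ngpus))
        if self.replicas is None:
            self.replicas = nn.parallel.replicate(self.net, device_ids)
        args, kwargs = torch.nn.parallel.scatter_gather.scatter_kwargs(
            args, kwargs, device_ids, 0)
        outputs = torch.nn.parallel.parallel_apply(
            self.replicas, args, kwargs, device_ids)
        out = torch.nn.parallel.scatter_gather.gather(outputs, 0, 0)
        loss = self.compute_loss(batch, out)
        if train:
            self.update(loss)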
    def update(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        # gradient clip, etc.
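        # For example (not in the original gist), gradient clipping could be:
        #     torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=5.0)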
        self.optimizer.step()
    def epoch(self, loader):
        self.net.train()
        for batch in loader:
            self.step(batch, train=True)

    def test(self, loader):
        self.net.eval()
        for batch in loader:
            self.step(batch, train=False)
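        # Optional (not in the original gist): wrapping the evaluation loop in
        # torch.no_grad() would skip building the autograd graph, e.g.
        #     with torch.no_grad():
        #         for batch in loader:
        #             self.step(batch, train=False)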


class Net(nn.Module):
    def __init__(self, src, a):
        super(Net, self).__init__()
        self.src = src
        self.a = a
        # in/out channels, kernel size, stride; further args elided in the original.
        self.conv1 = nn.Conv1d(1, 2, 3, 4)

    def _fwd_1(self, x):
        x = self.conv1(x)
        x = x + self.a
        return x

    def _fwd_2(self, x):
        # something similar to _fwd_1
        return x

    def forward(self, x):
        x = self._fwd_1(x)
        return self._fwd_2(x)


class NetWrapper(nn.Module):
    """
    If you want to run a custom (i.e. non-forward) method of your model,
    the easiest option is to wrap the model in another module and call the
    desired method from the wrapper's forward.

    Alternatively, `parallel_apply` is concise multi-threaded code written
    in Python, and you could create a modified version that calls a desired
    method instead of forward:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py#L21

    You could also try "monkey-patching" forward whenever you need to change
    the behavior, but that seems like a bit much:

        def customforward(self, x):
            return blah

        net.forward = types.MethodType(customforward, net)
    """
    def __init__(self, net):
        super(NetWrapper, self).__init__()
        self.net = net

    def forward(self, x):
        return self.net._fwd_1(x)
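
# Hedged usage sketch (not in the original gist): once wrapped, the custom
# method runs under data_parallel just like an ordinary forward pass, e.g.
#     wrapped = NetWrapper(net)
#     out = torch.nn.parallel.data_parallel(
#         wrapped, x, device_ids=list(range(ngpus)))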


# Placeholder arguments, as in the original sketch.
net = Net(src, a)
# DataLoader's keyword is num_workers; the dataset argument stays elided here,
# as it was in the original.
train_loader = DataLoader(..., num_workers=workers * ngpus, collate_fn=my_collate_fn)
test_loader = DataLoader(..., num_workers=workers * ngpus, collate_fn=my_collate_fn)
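# The original gist never constructs `optimizer`; a minimal sketch (any
# torch.optim optimizer would do, hyperparameters are placeholders):
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)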
trainer = Trainer(net, optimizer, ngpus)
# trainer = Trainer(NetWrapper(net), optimizer, ngpus)

for i in range(nepochs):
    trainer.epoch(train_loader)
    trainer.test(test_loader)