pt.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class Trainer(object):
    def __init__(self, net, optimizer, ngpus=1):
        self.net = net
        self.optimizer = optimizer
        self.ngpus = ngpus
        self.replicas = None  # only used by the cached-replica variant described below

    def step(self, batch, train, *args, **kwargs):
        """
        Alternatively, you can avoid replicating your model on each batch.
        This is particularly useful if your model has any type of state.

            if self.replicas is None:
                self.replicas = nn.parallel.replicate(self.net, device_ids)
            args, kwargs = torch.nn.parallel.scatter_gather.scatter_kwargs(
                args, kwargs, device_ids, 0)
            outputs = torch.nn.parallel.parallel_apply(
                self.replicas, args, kwargs, device_ids)
            out = torch.nn.parallel.scatter_gather.gather(outputs, 0, 0)

        Note that you might need to scale the number of workers for your
        DataLoader up with the number of GPUs.
        """
        # How `batch` maps onto `args`/`kwargs` for the net is model-specific
        # and left open in this sketch.
        if self.ngpus > 1:
            # data_parallel scatters the inputs along dim 0 across the GPUs and
            # gathers the per-replica outputs back along dim 0, so an output
            # tensor of per-GPU shape (A, B, C, ...) comes back with shape
            # (N x A, B, C, ...), where N is the number of GPUs.
            device_ids = list(range(self.ngpus))
            out = torch.nn.parallel.data_parallel(
                self.net, args, device_ids,
                module_kwargs=kwargs)
        else:
            out = self.net(*args, **kwargs)
        loss = self.compute_loss(batch, out)  # compute_loss is model-specific and not shown here
        if train:
            self.update(loss)

    def update(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        # gradient clip, etc.
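        # For example, a hedged sketch using clip_grad_norm_ (the max_norm
        # value here is an arbitrary choice, not from the original gist):
        #     torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=5.0)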
        self.optimizer.step()

    def epoch(self, loader):
        self.net.train()
        for batch in loader:
            self.step(batch, train=True)

    def test(self, loader):
        self.net.eval()
        # Consider wrapping the loop in torch.no_grad() to save memory
        # during evaluation.
        for batch in loader:
            self.step(batch, train=False)
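

# A minimal sketch (hypothetical class, not in the original gist) of the
# cached-replica alternative described in Trainer.step's docstring: replicate
# the net once and reuse the replicas on every batch. Caveat: the replicas
# hold copies of the parameters, so after an optimizer step on self.net you
# would need to re-replicate (or otherwise broadcast the updated weights)
# before the replicas reflect them.
class CachedReplicaTrainer(Trainer):
    def step(self, batch, train, *args, **kwargs):
        device_ids = list(range(self.ngpus))
        if self.replicas is None:
            self.replicas = nn.parallel.replicate(self.net, device_ids)
        scattered_args, scattered_kwargs = torch.nn.parallel.scatter_gather.scatter_kwargs(
            args, kwargs, device_ids, 0)
        # parallel_apply expects one replica per scattered input chunk.
        outputs = torch.nn.parallel.parallel_apply(
            self.replicas[:len(scattered_args)], scattered_args, scattered_kwargs)
        out = torch.nn.parallel.scatter_gather.gather(outputs, 0, 0)
        loss = self.compute_loss(batch, out)
        if train:
            self.update(loss)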


class Net(nn.Module):
    def __init__(self, src, a):
        super(Net, self).__init__()
        self.src = src
        self.a = a
        self.conv1 = nn.Conv1d(1, 2, 3, 4, ...)

    def _fwd_1(self, x):
        x = self.conv1(x)
        x = x + self.a
        return x

    def _fwd_2(self, x):
        # something similar to _fwd_1
        ...

    def forward(self, x):
        x = self._fwd_1(x)
        return self._fwd_2(x)


class NetWrapper(nn.Module):
    """
    If you want to run a custom (i.e. non-forward) method of your model,
    the easiest approach is to wrap the model in another module and call
    the desired method from the wrapper's forward.

    Alternatively, `parallel_apply` is concise multi-threaded code written
    in Python, and you could create a modified version that calls the
    desired method instead of forward:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py#L21

    You could also try "monkey-patching" forward whenever you need to change
    the behavior, but it seems a little much:

        def customforward(self, x):
            return blah

        net.forward = types.MethodType(customforward, net)
    """
    def __init__(self, net):
        super(NetWrapper, self).__init__()
        self.net = net

    def forward(self, x):
        return self.net._fwd_1(x)
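

# A hedged sketch (hypothetical helper, not in the original gist) showing how
# the wrapper lets data_parallel drive a non-forward method such as Net._fwd_1.
# `x` is assumed to be batched along dim 0 so that it can be scattered across
# the GPUs.
def run_fwd_1_parallel(net, x, ngpus=1):
    wrapped = NetWrapper(net)
    if ngpus > 1:
        return torch.nn.parallel.data_parallel(
            wrapped, x, device_ids=list(range(ngpus)))
    return wrapped(x)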


net = Net(a, b, ...)
optimizer = ...  # e.g. an optimizer from torch.optim over net.parameters()
train_loader = DataLoader(..., num_workers=workers * ngpus, collate_fn=my_collate_fn)
test_loader = DataLoader(..., num_workers=workers * ngpus, collate_fn=my_collate_fn)
trainer = Trainer(net, optimizer, ngpus)
# trainer = Trainer(NetWrapper(net), optimizer, ngpus)
for i in range(nepochs):
    trainer.epoch(train_loader)
    trainer.test(test_loader)
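
# A hedged sketch of the monkey-patching alternative mentioned in the
# NetWrapper docstring: rebind forward so the existing parallel machinery
# calls _fwd_1 instead. The NetWrapper approach above is usually cleaner.
#
#     import types
#
#     def _fwd_1_as_forward(self, x):
#         return self._fwd_1(x)
#
#     net.forward = types.MethodType(_fwd_1_as_forward, net)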