The real CoordConv in PyTorch. It can auto-infer the x-y dimensions in tensors. Use it without pain. 💜
import torch
import torch.nn as nn


class AddCoords(nn.Module):
    """Appends normalized coordinate channels (and optionally a radius channel) to the input."""

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape (batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        # Build integer coordinate grids, broadcast across the other spatial dimension.
        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        # Normalize to [0, 1], then rescale to [-1, 1].
        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        # Expand to the batch dimension: (batch, 1, x_dim, y_dim).
        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            # Radius channel, matching the paper's reference code (the 0.5 offset is
            # applied after rescaling to [-1, 1]). type_as keeps the device consistent.
            rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
            ret = torch.cat([ret, rr.type_as(input_tensor)], dim=1)

        return ret


class CoordConv(nn.Module):
    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        # Two extra coordinate channels, plus one more for the radius channel if enabled.
        self.conv = nn.Conv2d(in_channels + 2 + int(with_r), out_channels, **kwargs)

    def forward(self, x):
        ret = self.addcoords(x)
        ret = self.conv(ret)
        return ret
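If it helps, a quick usage sketch (the shapes and channel counts below are illustrative, not part of the gist). CoordConv accepts the same keyword arguments as nn.Conv2d, so it works as a drop-in replacement:

    x = torch.randn(8, 3, 64, 64)  # (batch, channels, height, width)
    conv = CoordConv(3, 16, with_r=True, kernel_size=3, padding=1)
    y = conv(x)  # AddCoords makes the conv see 3 + 2 + 1 = 6 input channels
    print(y.shape)  # torch.Size([8, 16, 64, 64])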
Hi @leVirve, thanks for the implementation! Can I update my repo according to yours? Or, if you want, you can create a pull request to my repo.

@mkocabas Sure! But note that this is an alternative implementation.
I'm looking into the author's version and will make a pull request for your project. 😄
@leVirve This is great!
FYI, there was one problem I ran into when running multiple experiments on multiple GPUs: I kept getting out-of-memory errors on the line xx_channel.type_as(input_tensor), even with small batch sizes. It looks like I hit the same issue as pytorch/pytorch#3477.
The solution that worked for me was wrapping the body of AddCoords.forward in with torch.cuda.device_of(input_tensor):,
though maybe that with should wrap the entire training code. Anyhow, I hope this helps anyone else who runs into this issue.
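For anyone applying that workaround, a rough sketch of what it could look like (my paraphrase of the fix above, not code from the gist; AddCoordsSafe is a hypothetical name). On newer PyTorch versions, passing device=input_tensor.device to torch.arange inside forward is another way to avoid the cross-device allocation:

    class AddCoordsSafe(AddCoords):
        # Hypothetical subclass applying the fix described in this thread.
        def forward(self, input_tensor):
            # torch.cuda.device_of makes input_tensor's GPU the current device inside
            # the block (a no-op for CPU tensors), so intermediates such as the
            # coordinate grids are allocated on the same GPU as the input.
            with torch.cuda.device_of(input_tensor):
                return super().forward(input_tensor)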