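```lua
require 'nn'
require 'stn' -- github.com/qassemoquab/stnbhwd by Maxime Oquab
-- Load a pre-trained localization network ('your_locnet.t7' is a placeholder path)
local localization_network = torch.load('your_locnet.t7')
-- ConcatTable feeds the same input to each branch and returns a table of outputs
local ct = nn.ConcatTable()
-- Branch 1 keeps the data as-is and only changes its layout from BDHW to BHWD
local branch1 = nn.Transpose({3,4},{2,4})
-- Branch 2 will hold the layers that predict the transformation parameters
local branch2 = nn.Sequential()
```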
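`nn.Transpose` applies its pairs of dimension swaps one after another, in the order given. The sketch below is plain Lua (no Torch needed) and tracks only the dimension labels, to show why `{3,4},{2,4}` turns a BDHW batch into BHWD. The helper `apply_swaps` is hypothetical and is not part of `nn`; it only mirrors the chaining behaviour.

```lua
-- Hypothetical helper: apply a list of pairwise dimension swaps, in order,
-- to a table of dimension labels (mirrors how nn.Transpose chains transposes).
local function apply_swaps(dims, swaps)
  local out = {}
  for i, d in ipairs(dims) do out[i] = d end  -- copy so the input is not mutated
  for _, s in ipairs(swaps) do
    out[s[1]], out[s[2]] = out[s[2]], out[s[1]]
  end
  return out
end

-- Branch 1 of the ST: swap dims 3 and 4, then dims 2 and 4
local to_bhwd = apply_swaps({'B', 'D', 'H', 'W'}, {{3, 4}, {2, 4}})
print(table.concat(to_bhwd))  -- BHWD

-- Applying the same swaps in reverse order restores the BDHW layout
local back = apply_swaps(to_bhwd, {{2, 4}, {3, 4}})
print(table.concat(back))  -- BDHW
```

Tracing it by hand: BDHW becomes BDWH after the first swap and BHWD after the second, which is the layout the `stn` modules are named after.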
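```lua
-- Returns the number of elements produced by a stack of conv modules for a
-- single input_size x input_size image, along with the output's spatial size.
-- `networks` is the module table this function belongs to (defined elsewhere).
function networks.convs_noutput(convs, input_size)
  input_size = input_size or networks.base_input_size
  -- Read the input-channel count from the first layer; multiscale convs
  -- wrap that layer in one more container
  local nbr_input_channels = convs[1]:get(1).nInputPlane or
                             convs[1]:get(1):get(1).nInputPlane
  -- Dummy batch of one image: only its shape matters, not its contents
  local output = torch.Tensor(1, nbr_input_channels, input_size, input_size)
  for _, conv in ipairs(convs) do
    output = conv:forward(output)
  end
  -- nElement() = channels * height * width (batch size is 1); size(3) = height
  return output:nElement(), output:size(3)
end
```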
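`convs_noutput` finds the output size empirically, by forwarding a dummy tensor. The same numbers can be worked out by hand with the usual floor-mode formula for convolutions and poolings, `out = floor((in + 2*pad - k) / stride) + 1`. The plain-Lua sketch below assumes square inputs and uses hypothetical layer descriptors (`k`, `stride`, `pad`, `planes`) rather than real `nn` modules. It is meant only as a cross-check on the arithmetic.

```lua
-- Output size of one conv/pool layer along one spatial dimension (floor mode)
local function conv_out_size(n, k, stride, pad)
  return math.floor((n + 2 * pad - k) / stride) + 1
end

-- Analytic counterpart of networks.convs_noutput for hypothetical descriptors:
-- k = kernel size, stride (default 1), pad (default 0), planes = output
-- channels (nil for pooling, which keeps the channel count unchanged).
local function convs_noutput_analytic(layers, input_size, in_planes)
  local size, planes = input_size, in_planes
  for _, l in ipairs(layers) do
    size = conv_out_size(size, l.k, l.stride or 1, l.pad or 0)
    planes = l.planes or planes
  end
  -- elements of one output sample, and its spatial size
  return planes * size * size, size
end

local layers = {
  {k = 5, planes = 16},  -- conv 5x5: 32 -> 28
  {k = 2, stride = 2},   -- max-pool 2x2: 28 -> 14
  {k = 5, planes = 32},  -- conv 5x5: 14 -> 10
  {k = 2, stride = 2},   -- max-pool 2x2: 10 -> 5
}
print(convs_noutput_analytic(layers, 32, 3))  -- prints 800 and 5
```

The forward-pass approach in `convs_noutput` avoids keeping such a formula in sync with every layer type (for example multiscale containers), which is likely why it probes the network directly.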
```lua
require 'nn'
require 'stn'
------
-- Prepare your localization network
local localization_network = torch.load('your_locnet.t7')
------
-- Prepare both branches of the spatial transformer (ST)
local ct = nn.ConcatTable()
```
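Both ST branches receive the same input because they sit inside `nn.ConcatTable`, which applies each member module to that input and returns a table with one output per member, in insertion order. The plain-Lua sketch below uses ordinary functions in place of modules. `concat_table` and the two stand-in branches are hypothetical and only illustrate the data flow.

```lua
-- Plain-Lua stand-in for nn.ConcatTable: every branch sees the same input,
-- and the outputs are collected in a table in branch order.
local function concat_table(branches)
  return function(input)
    local outputs = {}
    for i, branch in ipairs(branches) do
      outputs[i] = branch(input)
    end
    return outputs
  end
end

-- Hypothetical branches: the first passes the input through unchanged (like
-- the layout-only Transpose branch), the second derives a value from it
-- (like the localization network). The "input" here is just a number.
local ct = concat_table({
  function(x) return x end,
  function(x) return 2 * x end,
})

local out = ct(20)
print(out[1], out[2])  -- prints 20 and 40
```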