// Debug dump of the differentiated JIT graph: print the forward (grad.f)
// and backward (grad.df) code objects before forwarding into the module.
std::cout << "Forwarding into jit module" << std::endl
          << "Forward code:" << std::endl
          << *grad.f.get() << std::endl
          << "Backward code:" << std::endl
          << *grad.df.get() << std::endl
          << "End print !" << std::endl;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'nn'
require 'stn'

------
-- Prepare your localization network
-- (a pre-trained net that predicts the transformation parameters)
local localization_network = torch.load('your_locnet.t7')

------
-- Prepare both branches of the spatial transformer
local ct = nn.ConcatTable()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- Compute the output size of a stack of convolutional modules.
-- Runs a dummy forward pass through every module in `convs` to measure
-- the resulting tensor, instead of computing sizes analytically.
-- @param convs array of conv containers (plain, or multiscale where the
--   first submodule is itself a container — hence the two-level lookup)
-- @param input_size optional spatial size of the square input; defaults
--   to networks.base_input_size
-- @return total number of output elements, and the output spatial size
function networks.convs_noutput(convs, input_size)
  input_size = input_size or networks.base_input_size
  -- Get the number of channels for conv that are multiscale or not
  local nbr_input_channels = convs[1]:get(1).nInputPlane or
    convs[1]:get(1):get(1).nInputPlane
  local output = torch.Tensor(1, nbr_input_channels, input_size, input_size)
  for _, conv in ipairs(convs) do
    output = conv:forward(output)
  end
  return output:nElement(), output:size(3)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'nn'
require 'stn' -- github.com/qassemoquab/stnbhwd by Maxime Oquab

-- Pre-trained localization network predicting transform parameters.
local localization_network = torch.load('your_locnet.t7')

-- Two branches: one transposes the input to BHWD layout for the sampler,
-- the other computes the sampling grid.
local ct = nn.ConcatTable()
local branch1 = nn.Transpose({3,4},{2,4})
local branch2 = nn.Sequential()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local threads = require "threads"
-- Share serialized objects (e.g. tensors) across worker threads instead
-- of copying them.
threads.Threads.serialization('threads.sharedserialize')

-- Number of independent task pools (was an accidental global in the
-- original; declared local here).
local n_task = 3
local pools = {}
for task=1,n_task do | |
pools[task] = threads.Threads(5, | |
function() | |
-- Needed only for serialized elements |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch import nn
from torchviz import make_dot
from torch.autograd.gradcheck import gradcheck

# Double precision keeps gradcheck's finite-difference comparison accurate.
torch.set_default_tensor_type(torch.DoubleTensor)

# Small 2-2-2-1 MLP without biases.
my_mod = nn.Sequential(
    nn.Linear(2, 2, bias=False),
    nn.Sigmoid(),
    nn.Linear(2, 2, bias=False),
    nn.Sigmoid(),
    nn.Linear(2, 1, bias=False),
)
params = list(my_mod.parameters())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch import nn
from torch.nn import functional as F


class EasyDataParallel(nn.Module):
    """Thin data-parallel wrapper configured from a list of GPU ids.

    NOTE(review): only the head of ``__init__`` was visible in the
    scraped source; the remainder of the class is not reproduced here.
    """

    def __init__(self, gpus):
        super().__init__()
        # Handle cpu / 1 gpu case better
        assert isinstance(gpus, list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from patch_convolution import *
import torch
import torch.nn as nn
import time

# ---------------
# Parameters
# ---------------

# Number of profile iterations to run
itt = 30
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch import nn
from torch.optim.sgd import sgd
import gc
import objgraph
import weakref
def all(): | |
# Only a subset of the args you could have | |
def set_sgd_hook(mod, p, lr, weight_decay, momentum): |
Older | Newer