Last active
April 5, 2016 09:11
-
-
Save coodoo/65c1d011a1a5c33d30b9885492063dc8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[ | |
- A better way to train an RNNLM on a large dataset:
- persist the weights of the LookupTable to a file
- and load them back later for use in sentiment analysis tasks.
- Detailed discussion: https://github.com/torch/nn/issues/747
- Thanks to @fmassa for pointing in the right direction.
- One usually just copies the weights from a pre-trained model. For example:
old_lookup = torch.load(...) -- can be a full model | |
lookup = nn.LookupTable(...) | |
lookup.weight:copy(old_lookup.weight) | |
-- add lookup in your network | |
- Another way (for example, if you only want to fine-tune the last layer) is to load the full network,
- remove its last layer, and add a new randomly initialized layer on top of the network.
net = torch.load(...) | |
net:remove() -- remove last layer | |
net:add(nn.Linear(4096,21)) -- add new layer replacing the old one | |
]]-- | |
--[[
  Demo script: build a small LookupTable -> LSTM -> Linear classifier, run a
  single training step on mock data, save the LookupTable weights to disk,
  then load them into a freshly constructed LookupTable.
]]--
require 'nn'
require 'rnn'

-- Named once so the restored LookupTable is guaranteed to have the same
-- shape as the one whose weights were saved.
local vocabSize = 3       -- rows of the lookup table (valid input indices 1..3)
local embeddingSize = 4   -- columns of the lookup table / LSTM input size
local hiddenSize = 4      -- LSTM output size
local nClasses = 3        -- width of the final Linear layer (targets are 1..3)
local learningRate = 0.1
local weightsPath = 'z1.t7'

local criterion = nn.ClassNLLCriterion()
local lookup = nn.LookupTable(vocabSize, embeddingSize)

local mlp = nn.Sequential()
mlp:add(lookup)
-- NOTE(review): SplitTable(1, 2) is given nInputDims = 2, so a 3-D lookup
-- output should be treated as batched and split along its first non-batch
-- dimension -- confirm against the nn.SplitTable documentation.
mlp:add(nn.SplitTable(1, 2))
mlp:add(nn.Sequencer(nn.FastLSTM(embeddingSize, hiddenSize)))
mlp:add(nn.SelectTable(-1))  -- keep only the last element of the table
mlp:add(nn.Linear(hiddenSize, nClasses))
mlp:add(nn.LogSoftMax())

-- mock questions and answers
local inputs = torch.Tensor{{1, 2, 3}, {2, 3, 1}}
local targets = torch.Tensor{3, 2}

-- training: one full step. Gradients are cleared first and applied
-- afterwards; without the update the weights saved below would just be the
-- random initial values.
mlp:zeroGradParameters()
local outputs = mlp:forward(inputs)
criterion:forward(outputs, targets)  -- loss value is not used in this demo
local gradOutputs = criterion:backward(outputs, targets)
mlp:backward(inputs, gradOutputs)
mlp:updateParameters(learningRate)

-- persist only the embedding weights, not the whole model
print('write out')
print(lookup.weight)
torch.save(weightsPath, lookup.weight)

-- restore the weights into a brand-new LookupTable of the same shape
local old = torch.load(weightsPath)
local newLookup = nn.LookupTable(vocabSize, embeddingSize)
newLookup.weight:copy(old)

-- fail loudly if the round trip through disk did not preserve the weights
assert(newLookup.weight:equal(lookup.weight),
       'restored LookupTable weights differ from the saved ones')
print('after copy', newLookup.weight)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment