Last active
May 1, 2017 14:48
-
-
Save farrajota/358de9ab06ebd8542f23f2102fb9c45a to your computer and use it in GitHub Desktop.
Small example of how to create a binary tree in Torch7 using nn containers and nngraph.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[
Create a binary tree in two ways: with nn containers and with an nn.gModule
(nngraph). This example is deliberately simple: all the fully-connected layers
are of size 100. It should be easy to modify it to use fc layers with different
input/output sizes (for example, by passing a table that stores the
input/output configuration for each sub-branch level).
]]
require 'nn' | |
require 'nngraph' | |
-- (1) Example of an N-level binary tree built from nn containers.

--- Create one tree node: two Linear layers applied in parallel to one input.
-- Layout: Sequential{ ConcatTable{ Sequential{fc1}, Sequential{fc2} } }.
-- Each branch gets its own nn.Sequential so deeper levels can later be
-- appended to it with :add().
-- @tparam number ninputs  input size of both Linear layers
-- @tparam number noutputs output size of both Linear layers
-- @return the nn.Sequential wrapping the pair
local function createPairFC(ninputs, noutputs)
  local fc1 = nn.Linear(ninputs, noutputs)
  local fc2 = nn.Linear(ninputs, noutputs)
  local pair = nn.Sequential()
  pair:add(nn.ConcatTable():add(nn.Sequential():add(fc1)):add(nn.Sequential():add(fc2)))
  return pair
end

--- Recursively append `n_sub_trees` levels of binary branching to `network`.
-- Every call appends one pair of Linear layers. While more than one level
-- remains, the next level is added inside each of the pair's two branches,
-- so the finished tree has 2^n_sub_trees leaf Linear layers.
-- @param network container (e.g. nn.Sequential) the pair is appended to
-- @tparam number n_sub_trees number of branching levels to add (must be >= 1)
-- @tparam number ninputs  input size of every Linear layer
-- @tparam number noutputs output size of every Linear layer
local function recursiveAddSubtrees(network, n_sub_trees, ninputs, noutputs)
  -- Fail fast: depths < 1 used to silently produce a one-level tree.
  assert(type(n_sub_trees) == 'number' and n_sub_trees >= 1,
    'recursiveAddSubtrees: n_sub_trees must be a number >= 1')
  local pair_tree = createPairFC(ninputs, noutputs)
  network:add(pair_tree)
  if n_sub_trees > 1 then
    -- pair_tree.modules[1] is the ConcatTable; its two modules are the branch
    -- Sequentials that receive the next level of the tree.
    local concat = pair_tree.modules[1]
    recursiveAddSubtrees(concat.modules[1], n_sub_trees - 1, ninputs, noutputs)
    recursiveAddSubtrees(concat.modules[2], n_sub_trees - 1, ninputs, noutputs)
  end
end
-- Create the container-based tree: a root Linear layer followed by
-- `n_sub_trees` levels of binary branching.
local n_sub_trees = 3    -- depth of the tree (number of branching levels)
local input_size = 10    -- size of the model input
local hidden_size = 100  -- input/output size of the branch fc layers
local branches = nn.Sequential() -- container that receives the sub-trees
-- Built before the root layer, as in the original example, so module
-- initialisation order is unchanged.
recursiveAddSubtrees(branches, n_sub_trees, hidden_size, hidden_size)
local bin_tree_model = nn.Sequential() -- main model container
bin_tree_model:add(nn.Linear(input_size, hidden_size)) -- 'root' fully-connected layer
bin_tree_model:add(branches) -- attach the branches below the root
print(bin_tree_model) -- print the binary tree
-- (2) Example of an N-level binary tree built with nngraph.
-- Warning: displaying the graph requires qlua.

--- Create the two Linear layers of one tree node (not yet wired into a graph).
-- @tparam number ninputs  input size of both layers
-- @tparam number noutputs output size of both layers
-- @return fc1, fc2 (two nn.Linear modules)
local function createPairFCGraph(ninputs, noutputs)
  local fc1 = nn.Linear(ninputs, noutputs)
  local fc2 = nn.Linear(ninputs, noutputs)
  return fc1, fc2
end

--- Recursively wire `n_sub_trees` levels of binary branching below `root_fc`.
-- Each level applies two new Linear layers to the parent node. The nodes of
-- the last level are appended to `networkGraphTable` depth-first (fc1 subtree
-- before fc2 subtree), giving 2^n_sub_trees leaf nodes in total.
-- @tparam table networkGraphTable list that collects the leaf nngraph nodes
-- @param root_fc parent nngraph.Node the pair is attached to
-- @tparam number n_sub_trees number of branching levels to add (must be >= 1)
-- @tparam number ninputs  input size of every Linear layer
-- @tparam number noutputs output size of every Linear layer
local function recursiveAddSubtreesGraph(networkGraphTable, root_fc, n_sub_trees, ninputs, noutputs)
  -- Fail fast: depths < 1 used to silently produce a one-level tree.
  assert(type(n_sub_trees) == 'number' and n_sub_trees >= 1,
    'recursiveAddSubtreesGraph: n_sub_trees must be a number >= 1')
  local fc1, fc2 = createPairFCGraph(ninputs, noutputs)
  if n_sub_trees > 1 then
    recursiveAddSubtreesGraph(networkGraphTable, fc1({root_fc}), n_sub_trees - 1, ninputs, noutputs)
    recursiveAddSubtreesGraph(networkGraphTable, fc2({root_fc}), n_sub_trees - 1, ninputs, noutputs)
  else
    -- last level: these nodes are the leaves of the tree
    networkGraphTable[#networkGraphTable + 1] = fc1({root_fc})
    networkGraphTable[#networkGraphTable + 1] = fc2({root_fc})
  end
end
-- Create the nngraph counterpart of example (1): a tree of nngraph.Node objects.
local graph_depth = 3         -- number of branching levels
local graph_input_size = 10   -- size of the model input
local graph_hidden_size = 100 -- input/output size of the branch fc layers
local networkGraphTable = {}  -- stores the leaf nodes, i.e. the outputs of the nn.gModule
-- Root fc layer. The trailing () registers the module as an "nngraph.Node".
local root_fc = nn.Linear(graph_input_size, graph_hidden_size)()
recursiveAddSubtreesGraph(networkGraphTable, root_fc, graph_depth,
  graph_hidden_size, graph_hidden_size) -- recursively keep adding pairs of fc layers
-- define the nn.gModule (nngraph)
local bin_tree_modelGraph = nn.gModule(
  {root_fc},         -- input node of the model
  networkGraphTable  -- output nodes of the model (the tree's leaves)
)
-- Optional visualisation: only attempted when the 'qt' package can be loaded
-- (start this script with qlua to plot the graph); otherwise it is skipped.
if pcall(require, 'qt') then
  -- draw the forward node graph (fg) of the nngraph binary tree
  graph.dot(bin_tree_modelGraph.fg, 'binary tree')
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment