BRNN
// There are two ways of doing this that I could think of.
// A.
// We define a bidirectional layer that takes one forward-direction RNN cell (LSTM/GRU/...) and one backward-direction RNN cell.
BidirectionalLayer<> b_unit;
// This unit has an Add() method. The method is "different" from the Add() method of other layers because the user needs to
// specify a forward RNN cell and a backward RNN cell.
// There are two ways I could think of doing this. Either pass the direction as a template parameter (A1):
template<bool forwardPolicy, class LayerType, class... Args>
void Add(Args... args) { ... }
// or pass it as a runtime argument (A2):
template<class LayerType, class... Args>
void Add(bool forwardPolicy, Args... args) { ... }
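// For illustration, a hypothetical usage sketch of option A. Assumptions: the
// BidirectionalLayer<> name and both Add() signatures are the proposal above,
// not existing mlpack API; RNN<>, LSTM<> and Linear<> are the existing mlpack
// types; inputSize, hiddenSize, outputSize and rho are placeholders.
RNN<> model(rho);
BidirectionalLayer<> bUnit;
bUnit.Add<true, LSTM<>>(inputSize, hiddenSize, rho);    // forward cell (A1 style)
bUnit.Add<false, LSTM<>>(inputSize, hiddenSize, rho);   // backward cell (A1 style)
// ... or, with A2: bUnit.Add<LSTM<>>(false, inputSize, hiddenSize, rho);
model.Add(std::move(bUnit));
model.Add<Linear<>>(2 * hiddenSize, outputSize);        // assumes both directions' outputs are concatenated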
// Next, whenever we add a BidirectionalLayer<> to the RNN module, we remember the index of that layer.
// And instead of doing the forward pass time-step-wise, we do it layer-wise:
for (size_t seqNum = 0; seqNum < rho; seqNum++)
{
  Forward(stepData);
  ...
}
// changes to:
// Do the forward pass for the first layer over all time steps and store its outputs.
for (size_t seqNum = 0; seqNum < rho; seqNum++)
{
  stepData = ...
  boost::apply_visitor(ForwardVisitor(...), network.front());
  boost::apply_visitor(SaveOutputParameterVisitor(...), network.front());
}
// Do the forward pass for the remaining layers and store their outputs.
for (size_t i = 1; i < network.size(); i++)
{
  if (layer i is a BidirectionalLayer<>)
  {
    // Send the entire input (all time steps) to it: load the stored outputs of
    // the previous layer for every time step, assemble them, then do a single
    // Forward() call on the bidirectional layer.
    forwardData = ...
    for (size_t seqNum = 0; seqNum < rho; seqNum++)
    {
      boost::apply_visitor(LoadOutputParameterVisitor(...), network[i - 1]);
    }
    boost::apply_visitor(ForwardVisitor(...), network[i]);
  }
  else
  {
    for (size_t seqNum = 0; seqNum < rho; seqNum++)
    {
      stepData = ...
      boost::apply_visitor(ForwardVisitor(...), network[i]);
      boost::apply_visitor(SaveOutputParameterVisitor(...), network[i]);
    }
  }
}
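// For concreteness, a minimal sketch of what BidirectionalLayer<>'s own
// Forward() could do once it receives all rho time steps at once (option A).
// Assumptions: forwardCell, backwardCell and hiddenSize are members of the
// proposed layer, the per-step Forward() call on the cells, the one-column-
// per-time-step layout, and concatenating the two directions' outputs are all
// illustrative choices, not existing mlpack API.
void Forward(const arma::mat& input, arma::mat& output)
{
  const size_t rho = input.n_cols;
  arma::mat forwardOut(hiddenSize, rho), backwardOut(hiddenSize, rho);
  arma::mat stepOut;

  // The forward cell reads the sequence left to right.
  for (size_t t = 0; t < rho; t++)
  {
    forwardCell.Forward(input.col(t), stepOut);
    forwardOut.col(t) = stepOut;
  }

  // The backward cell reads the sequence right to left.
  for (size_t t = rho; t-- > 0; )
  {
    backwardCell.Forward(input.col(t), stepOut);
    backwardOut.col(t) = stepOut;
  }

  // Stack both directions' outputs for every time step.
  output = arma::join_cols(forwardOut, backwardOut);
}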
// See this image:
// https://drive.google.com/file/d/1bm2icsD4palEk5vfxbyIIuat69jAp-Xf/view?usp=sharing
// B.
// We change the API of the RNN's Add() method so it knows when the layer being inserted is a backward-direction RNN cell.
// The usual Add() method stays:
template<class LayerType, class... Args>
void Add(Args... args) { ... }
// But we also add this Add() overload:
template<bool forwardPolicy, size_t depends, class LayerType, class... Args>
void Add(Args... args) { ... }
// The usual Add() simply calls it as Add<true, 1, ...>(...);
// depends = 1 means the layer depends on the output of the previous layer.
// depends = 2 means it depends on the output of the layer before that.
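// For illustration, a hypothetical usage sketch of option B. Assumptions: the
// templated Add<forwardPolicy, depends, ...>() is the proposal above; how the
// forward and backward outputs are merged before the output layer is left
// open here, as it is in the proposal.
RNN<> model(rho);
model.Add<true, 1, LSTM<>>(inputSize, hiddenSize, rho);   // forward cell, fed by the layer before it
model.Add<false, 2, LSTM<>>(inputSize, hiddenSize, rho);  // backward cell; depends = 2 skips the forward cell and reads the same input
model.Add<true, 1, Linear<>>(hiddenSize, outputSize);     // placeholder; merging both directions still needs a design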
// And then the forward pass changes to the following.
// Assume there is only one forward RNN cell and one backward RNN cell in the RNN, and that the backward cell comes after the forward one.
// (The current implementation supports stacked RNNs, so stacked BRNNs should eventually be supported too.)
for (size_t i = 0; i < indexOfBackwardLayer; i++)
{
  for (size_t seqNum = 0; seqNum < rho; seqNum++)
  {
    stepData = ...
    boost::apply_visitor(ForwardVisitor(...), network[i]);
    boost::apply_visitor(SaveOutputParameterVisitor(...), network[i]);
  }
}
for (size_t i = indexOfBackwardLayer; i < network.size(); i++)
{
  // seqNum is unsigned, so count down from rho - 1 without testing seqNum >= 0.
  for (size_t seqNum = rho; seqNum-- > 0; )
  {
    stepData = ...
    boost::apply_visitor(ForwardVisitor(...), network[i]);
    boost::apply_visitor(SaveOutputParameterVisitor(...), network[i]);
  }
}
// See this image:
// https://drive.google.com/open?id=12RzlfXR9ryYxuqXKUsBAnUuZuYAbA8qN
// In both cases the backward and gradient passes also need to be modified.
// Tests:
// Use the (small) MNIST dataset that is already available and check that the BRNN gives decent accuracy (skipping runs that get stuck in local minima).
// Repeat the MNIST test with only a backward-direction RNN.
// Repeat the Reber grammar test with only a backward-direction RNN.
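// A rough sketch of the MNIST test, using the option A API for illustration.
// Assumptions: the BidirectionalLayer<> API from above, sequence data stored
// in arma::cube objects as in the existing mlpack RNN tests, and the exact
// accuracy threshold / retry policy still to be decided.
RNN<NegativeLogLikelihood<>> model(rho);
BidirectionalLayer<> bUnit;
bUnit.Add<true, LSTM<>>(inputSize, hiddenSize, rho);
bUnit.Add<false, LSTM<>>(inputSize, hiddenSize, rho);
model.Add(std::move(bUnit));
model.Add<Linear<>>(2 * hiddenSize, 10);
model.Add<LogSoftMax<>>();

arma::cube predictors, responses;   // load the small MNIST sequences here
model.Train(predictors, responses);
// Then predict on held-out data and require accuracy well above chance,
// rerunning a couple of times so one bad local minimum does not fail the test.
// The direction-only variants just drop the forward (or backward) cell.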