// Dynet autobatching issue
// Gist FilippoC/9172a0513d85629df2a24ad4a98e6311, created October 28, 2018 12:29.
// (GitHub notice: this file may contain bidirectional Unicode text that can be
// interpreted or compiled differently than it appears; review it in an editor
// that reveals hidden Unicode characters.)
#include <stdlib.h>

#include <iostream>
#include <random>
#include <vector>

#include "dynet/dynet.h"
#include "dynet/grad-check.h"
#include "dynet/param-init.h"
// Random engine and a uniform distribution over [-100, 100).
// NOTE(review): neither object is referenced anywhere in this file; the
// weights in main() are hard-coded (they do lie inside this range, so they may
// have been sampled from it originally). Kept so the file's globals are
// unchanged; <random> is now included explicitly instead of relying on a
// transitive include from the dynet headers.
std::default_random_engine generator;
std::uniform_real_distribution<float> distribution(-100.f, 100.f);
bool test_sparsemax(dynet::ParameterCollection& pc, dynet::Parameter& p_weights, unsigned rows, unsigned batches, bool manual_batching) | |
{ | |
dynet::ComputationGraph cg; | |
cg.set_immediate_compute(true); | |
cg.set_check_validity(true); | |
auto e_weights = dynet::parameter(cg, p_weights); | |
e_weights = dynet::reshape(e_weights, dynet::Dim({rows, 1}, batches)); | |
if (manual_batching) | |
{ | |
std::vector<dynet::Expression> v_values; | |
for (unsigned b = 0u ; b < batches ; ++b) | |
v_values.push_back( | |
dynet::sparsemax( | |
dynet::strided_select( | |
e_weights, | |
{(int) 1u, (int) 1u, (int) 1u}, | |
{(int) 0, (int) 0, (int) b}, | |
{(int) rows, (int) 1, (int) b+1} | |
) | |
) | |
); | |
for (unsigned b = 0u ; b < batches ; ++b) | |
{ | |
for (unsigned r = 0u ; r < rows ; ++r) | |
{ | |
auto e_output = dynet::strided_select( | |
v_values[b], | |
{(int) 1u}, | |
{(int) r}, | |
{(int) r+1} | |
); | |
auto v = check_grad(pc, e_output, 0); | |
if (v == 0) | |
{ | |
//check_grad(pc, e_output, 2); | |
return false; | |
} | |
} | |
} | |
return true; | |
} | |
else | |
{ | |
auto e_values = dynet::sparsemax(e_weights); | |
for (unsigned b = 0u ; b < batches ; ++b) | |
{ | |
for (unsigned r = 0u ; r < rows ; ++r) | |
{ | |
auto e_output = dynet::strided_select( | |
e_values, | |
{(int) 1u, (int) 1u, (int) 1u}, | |
{(int) r, (int) 0, (int) b}, | |
{(int) r+1, (int) 1, (int) b+1} | |
) | |
; | |
auto v = check_grad(pc, e_output, 0); | |
if (v == 0) | |
{ | |
//check_grad(pc, e_output, 2); | |
return false; | |
} | |
} | |
} | |
return true; | |
} | |
} | |
int main(int argc, char* argv[]) | |
{ | |
dynet::initialize(argc, argv); | |
const unsigned rows = 3u; | |
const unsigned batches = 3u; | |
std::vector<dynet::real> weights{ | |
-99.9984, -73.6924, 51.1211, | |
-8.26997, 6.55344, -56.2082, | |
-90.5911, 35.7729, 35.8593 | |
}; | |
dynet::ParameterCollection pc; | |
auto p_weights = pc.add_parameters( | |
{(unsigned int) weights.size()}, | |
dynet::ParameterInitFromVector(weights) | |
); | |
auto t1 = test_sparsemax(pc, p_weights, rows, batches, false); | |
std::cerr << "Automatic batching: " << (t1 ? "pass" : "fail") << "\n"; | |
auto t2 = test_sparsemax(pc, p_weights, rows, batches, true); | |
std::cerr << "Manual batching: " << (t2 ? "pass" : "fail") << "\n"; | |
} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.