Dynet: Static vs Dynamic CG
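A micro-benchmark of two ways to drive the same convolutional model through DyNet's C++ API: a dynamic variant that rebuilds the ComputationGraph for every batch, and a static variant that builds the graph once and only refreshes the input data before each forward/backward pass. The switch between the two comes down to which overload of dynet::input is used; the sketch below illustrates the idea (the function and variable names are illustrative, not taken from the file):

#include <vector>
#include "dynet/dynet.h"
#include "dynet/expr.h"

// Illustrative only: the two input styles used in the benchmark below.
void input_styles(dynet::ComputationGraph& cg, std::vector<dynet::real>& batch)
{
    const dynet::Dim dim({4u, 4u, 3u}, 2u); // toy shape: 4x4x3 images, batch of 2

    // Dynamic style: pass the vector itself; the node is tied to the data
    // supplied at creation time, so a fresh graph is built for every batch.
    auto x_dynamic = dynet::input(cg, dim, batch);

    // Static style: pass a pointer; the node reads whatever the vector holds
    // when the graph is executed, so one graph can be re-run after refilling
    // the vector in place.
    auto x_static = dynet::input(cg, dim, &batch);
}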
#include <chrono>
#include <cstring>   // std::memcpy
#include <iostream>  // std::cerr
#include <random>
#include <tuple>
#include <vector>

#include "dynet/dynet.h"
#include "dynet/expr.h"
// A small stack of convolutional layers: each layer is a "valid" conv2d
// followed by 2x2 max-pooling and a ReLU. H, W and C track the running
// output shape (only C is needed to size the next filter bank).
struct Conv
{
    std::vector<dynet::Parameter> fs;
    std::vector<dynet::Parameter> bs;

    Conv(
        dynet::ParameterCollection& pc,
        const std::tuple<unsigned, unsigned, unsigned>& input_shape,
        const std::vector<unsigned>& channels,
        const unsigned kernel_size = 3
    )
    {
        auto H = std::get<0>(input_shape);
        auto W = std::get<1>(input_shape);
        auto C = std::get<2>(input_shape);
        for (const auto channel : channels)
        {
            fs.push_back(pc.add_parameters({kernel_size, kernel_size, C, channel}));
            bs.push_back(pc.add_parameters({channel}));
            H = (H - 2u * (kernel_size / 2u)) / 2u;
            W = (W - 2u * (kernel_size / 2u)) / 2u;
            C = channel;
        }
    }

    dynet::Expression operator()(
        dynet::ComputationGraph& cg,
        dynet::Expression& input
    )
    {
        auto out = input;
        for (unsigned i = 0u; i < fs.size(); ++i)
        {
            auto f = dynet::parameter(cg, fs.at(i));
            auto b = dynet::parameter(cg, bs.at(i));
            out = dynet::conv2d(out, f, b, {1, 1});          // stride 1, no padding
            out = dynet::maxpooling2d(out, {2, 2}, {2, 2});  // halve H and W
            out = dynet::rectify(out);
        }
        return out;
    }
};
// Fill a buffer with standard-normal noise (default-seeded generator).
void fill_random(dynet::real* start, const unsigned size)
{
    std::default_random_engine generator;
    std::normal_distribution<dynet::real> distribution;
    for (unsigned i = 0u; i < size; ++i)
        start[i] = distribution(generator);
}
int main(int argc, char** argv)
{
    dynet::initialize(argc, argv);
    dynet::ParameterCollection pc;

    const unsigned n_iter = 10u;
    const unsigned L = 256u;                 // inputs are L x L x 3 images
    const unsigned sample_size = L * L * 3u;
    const unsigned data_size = 100u;         // number of random samples
    const unsigned batch_size = 50u;

    // Synthetic dataset: data_size random images drawn from a normal distribution.
    dynet::real* data = new dynet::real[data_size * sample_size];
    fill_random(data, data_size * sample_size);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> random_sample(0, static_cast<int>(data_size) - 1);

    Conv conv(pc, std::make_tuple(L, L, 3u), {32, 32});
    std::vector<dynet::real> batch(batch_size * sample_size);
    {
        // Dynamic variant: rebuild the computation graph at every iteration.
        auto start = std::chrono::steady_clock::now();
        float input_time = 0.f;
        float cg_time = 0.f;
        float fb_time = 0.f;
        for (unsigned iter = 0u; iter < n_iter; ++iter)
        {
            auto start_input = std::chrono::steady_clock::now();
            for (unsigned i = 0u; i < batch_size; ++i)
            {
                const unsigned j = random_sample(gen);
                std::memcpy(
                    batch.data() + i * sample_size,    // destination
                    data + j * sample_size,            // source
                    sample_size * sizeof(dynet::real)  // size in bytes
                );
            }
            auto end_input = std::chrono::steady_clock::now();

            auto start_cg = std::chrono::steady_clock::now();
            dynet::ComputationGraph cg;
            auto input = dynet::input(cg, dynet::Dim({L, L, 3u}, batch_size), batch);
            auto out = conv(cg, input);
            out = dynet::sum_elems(out);
            out = dynet::sum_batches(out);
            auto end_cg = std::chrono::steady_clock::now();

            auto start_fb = std::chrono::steady_clock::now();
            cg.forward(out);
            cg.backward(out);
            auto end_fb = std::chrono::steady_clock::now();

            input_time += std::chrono::duration_cast<std::chrono::milliseconds>(end_input - start_input).count();
            cg_time += std::chrono::duration_cast<std::chrono::milliseconds>(end_cg - start_cg).count();
            fb_time += std::chrono::duration_cast<std::chrono::milliseconds>(end_fb - start_fb).count();
        }
        auto end = std::chrono::steady_clock::now();
        float time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

        std::cerr << "Dynamic graph" << std::endl;
        std::cerr << "#############" << std::endl;
        std::cerr << "Total time: " << time / 1000.f << std::endl;
        std::cerr << "Batch creation time: " << input_time / 1000.f << std::endl;
        std::cerr << "CG time: " << cg_time / 1000.f << std::endl;
        std::cerr << "Forward/Backward time: " << fb_time / 1000.f << std::endl;
    }

    std::cerr << std::endl;
    {
        // Static variant: build the computation graph once and reuse it,
        // only refreshing the contents of `batch` before each iteration.
        auto start = std::chrono::steady_clock::now();
        float input_time = 0.f;
        float cg_time = 0.f;
        float fb_time = 0.f;

        auto start_cg = std::chrono::steady_clock::now();
        dynet::ComputationGraph cg;
        // Pass a pointer so the input node reads `batch` at execution time.
        auto input = dynet::input(cg, dynet::Dim({L, L, 3u}, batch_size), &batch);
        auto out = conv(cg, input);
        out = dynet::sum_elems(out);
        out = dynet::sum_batches(out);
        auto end_cg = std::chrono::steady_clock::now();
        // The graph is built once, so its construction cost is counted once.
        cg_time = std::chrono::duration_cast<std::chrono::milliseconds>(end_cg - start_cg).count();

        for (unsigned iter = 0u; iter < n_iter; ++iter)
        {
            auto start_input = std::chrono::steady_clock::now();
            for (unsigned i = 0u; i < batch_size; ++i)
            {
                const unsigned j = random_sample(gen);
                std::memcpy(
                    batch.data() + i * sample_size,    // destination
                    data + j * sample_size,            // source
                    sample_size * sizeof(dynet::real)  // size in bytes
                );
            }
            auto end_input = std::chrono::steady_clock::now();

            auto start_fb = std::chrono::steady_clock::now();
            cg.forward(out);
            cg.backward(out);
            auto end_fb = std::chrono::steady_clock::now();

            input_time += std::chrono::duration_cast<std::chrono::milliseconds>(end_input - start_input).count();
            fb_time += std::chrono::duration_cast<std::chrono::milliseconds>(end_fb - start_fb).count();
        }
        auto end = std::chrono::steady_clock::now();
        float time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

        std::cerr << "Static graph" << std::endl;
        std::cerr << "############" << std::endl;
        std::cerr << "Total time: " << time / 1000.f << std::endl;
        std::cerr << "Batch creation time: " << input_time / 1000.f << std::endl;
        std::cerr << "CG time: " << cg_time / 1000.f << std::endl;
        std::cerr << "Forward/Backward time: " << fb_time / 1000.f << std::endl;
    }
    delete[] data;
}
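Both timed blocks draw random batches the same way and run the same forward/backward pass; the only structural difference is that the static variant pays the graph-construction cost once, outside the iteration loop, while the dynamic variant pays it on every one of the n_iter iterations. Timings are accumulated in milliseconds and printed in seconds, so the lines to compare across the two reports are "CG time" and "Forward/Backward time".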