// GitHub gist JonathanRaiman/accddeb1ea3737d7db1bcd8036372350, created April 23, 2016 21:17:
// "Get reduction over all dimensions to work in mshadow".
// (Page notice: this file may contain hidden or bidirectional Unicode text that can be
// interpreted or compiled differently than it appears; review it in an editor that
// reveals hidden Unicode characters.)
/*
Reduction over all dimensions in mshadow. This requires changing the structs
in mshadow's expression templates to store their input expressions by value
instead of by reference: sum() below builds temporary sub-expressions whose
lifetime would otherwise end before the returned expression is evaluated.

Build:
    nvcc some_file.cu -std=c++11 -O3 -w -o some_file -I /usr/local/include

Usage:
    ./some_file
*/
#include <atomic>
#include <cassert>
#include <chrono>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>
// mshadow configuration options (must be defined before including mshadow)
#define MSHADOW_USE_MKL 0
#define MSHADOW_USE_CUDA 1
#include <mshadow/tensor.h>
#define EXP_F expf
// Timer class for measuring total reduce vs. row reduce.
// Elapsed time is accumulated per timer name in a process-wide table;
// report() prints that table and then clears it. All access to the table is
// serialized by timers_mutex.
class Timer {
  // steady_clock is monotonic, so measured intervals cannot be skewed by
  // wall-clock adjustments the way system_clock intervals can.
  using clock_type = std::chrono::steady_clock;
  // Accumulated microseconds per timer name; guarded by timers_mutex.
  static std::unordered_map<std::string, long long> timers;
  static std::mutex timers_mutex;
  std::string name;
  bool stopped;
  bool started;
  clock_type::time_point start_time;
public:
  // Registers `name` (with zero accumulated time if it is new) and, unless
  // autostart is false, starts measuring immediately.
  Timer(std::string name, bool autostart=true);
  // Stops the timer if it has not been stopped yet. A timer that was never
  // started contributes nothing.
  ~Timer();
  // Copying would add the same interval to the table twice.
  Timer(const Timer&) = delete;
  Timer& operator=(const Timer&) = delete;
  // Explicitly start the timer (at most once).
  void start();
  // Explicitly stop the timer (at most once) and add the elapsed time to this
  // name's total; adds nothing if the timer was never started.
  void stop();
  // Prints every registered timer's total in seconds, then clears the table.
  static void report();
};

std::unordered_map<std::string, long long> Timer::timers;
std::mutex Timer::timers_mutex;

Timer::Timer(std::string name, bool autostart) : name(name),
                                                 stopped(false),
                                                 started(false) {
  {
    // The lookup must happen under the lock too: another thread may be
    // inserting (and rehashing) concurrently. emplace is a no-op if the
    // name is already registered.
    std::lock_guard<std::mutex> guard(timers_mutex);
    timers.emplace(name, 0);
  }
  if (autostart)
    start();
}

void Timer::start() {
  assert(!started);
  start_time = clock_type::now();
  started = true;
}

void Timer::stop() {
  assert(!stopped);
  stopped = true;
  // start_time is meaningless if start() never ran; record nothing.
  if (!started)
    return;
  // Microsecond resolution: a single run can take well under a millisecond,
  // which would truncate to zero at millisecond resolution.
  const long long elapsed_us =
      std::chrono::duration_cast<std::chrono::microseconds>(
          clock_type::now() - start_time).count();
  std::lock_guard<std::mutex> guard(timers_mutex);
  timers[name] += elapsed_us;
}

Timer::~Timer() {
  if (!stopped) stop();
}

void Timer::report() {
  std::lock_guard<std::mutex> guard(timers_mutex);
  for (const auto& kv : timers) {
    std::cout << "\"" << kv.first << "\" => "
              << std::fixed << std::setw(5) << std::setprecision(4) << std::setfill(' ')
              << static_cast<double>(kv.second) / 1e6 << "s" << std::endl;
  }
  timers.clear();
}
using namespace mshadow; | |
template<typename SrcExp, typename DType, int etype> | |
auto sum(const mshadow::expr::Exp<SrcExp, DType, etype> &exp) -> decltype(sum_rows(reshape(exp, Shape2(1, 1)))) { | |
return sum_rows(reshape(exp, Shape2(expr::ShapeCheck<expr::ExpInfo<SrcExp>::kDim, SrcExp> | |
::Check(exp.self()).Size(), 1))); | |
} | |
template<typename R> | |
struct sigmoid { | |
MSHADOW_XINLINE static R Map(const R& a) { | |
return 1.0 / (1.0 + EXP_F(-a)); | |
} | |
}; | |
int main() { | |
int size = 5000; | |
int NUM_EXPERIMENTS = 1000; | |
Tensor<gpu, 2, float> src(NULL, Shape2(size, size)); | |
Tensor<gpu, 1, float> total_reduce(NULL, Shape1(1)); | |
Tensor<gpu, 1, float> row_reduce(NULL, Shape1(size)); | |
AllocSpace(&src); | |
AllocSpace(&total_reduce); | |
AllocSpace(&row_reduce); | |
// input is 0 and we know sigmoid(0) = 0.5 | |
src = 0.0; | |
for (int i = 0; i < NUM_EXPERIMENTS; i++) { | |
Timer t1("total_reduce"); | |
total_reduce = sum(expr::F<sigmoid<float>>(src)); | |
} | |
for (int i = 0; i < NUM_EXPERIMENTS; i++) { | |
Timer t1("row_reduce"); | |
row_reduce = sum_rows(expr::F<sigmoid<float>>(src)); | |
} | |
Timer::report(); | |
} |
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.