Last active: August 29, 2015 14:19
-
-
Save ShigekiKarita/433d396627e120506ce2 to your computer and use it in GitHub Desktop.
easy cuBLAS macro
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <type_traits>

#include <cublas_v2.h>
namespace detail { | |
template <typename Dtype, typename Scalar> | |
using if_ = typename std::enable_if< | |
std::is_same<Dtype, Scalar>::value, cublasStatus_t>::type; | |
#define CUBLAS_FACTORY(funcname) \ | |
template <typename Dtype, typename ... Args> \ | |
if_<float, Dtype> \ | |
funcname(Args ... args) \ | |
{ \ | |
return cublasS##funcname(args ...); \ | |
} \ | |
\ | |
template <typename Dtype, typename ... Args> \ | |
if_<double, Dtype> \ | |
funcname(Args ... args) \ | |
{ \ | |
return cublasD##funcname(args ...); \ | |
} \ | |
// cublasSgemm_v2 (float) と cublasDgemm_v2 (double) を多重定義した関数 gemm_v2<Dtype> を作る | |
CUBLAS_FACTORY(gemm_v2); | |
CUBLAS_FACTORY(axpy_v2); | |
#undef CUBLAS_FACTORY | |
} // namespace detail |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment