Skip to content

Instantly share code, notes, and snippets.

@ShigekiKarita
Last active August 29, 2015 14:19
Show Gist options
  • Save ShigekiKarita/433d396627e120506ce2 to your computer and use it in GitHub Desktop.
Save ShigekiKarita/433d396627e120506ce2 to your computer and use it in GitHub Desktop.
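Easy cuBLAS macro: generate float/double overloads of cuBLAS routines (e.g. `gemm_v2<float>` calls `cublasSgemm_v2`, `gemm_v2<double>` calls `cublasDgemm_v2`)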
#include <type_traits>

#include <cublas_v2.h>
namespace detail {

// SFINAE helper: names `cublasStatus_t` when `Dtype` is exactly `Scalar`, and
// is ill-formed otherwise, which removes the enclosing overload from the
// candidate set. It is used as the return type of the wrappers generated below,
// so the explicitly supplied template argument `Dtype` picks exactly one of the
// float/double overloads.
template <typename Dtype, typename Scalar>
using if_ = typename std::enable_if<
    std::is_same<Dtype, Scalar>::value, cublasStatus_t>::type;

// CUBLAS_FACTORY(funcname) defines, in the current namespace, a pair of function
// templates callable as `funcname<Dtype>(args...)`:
//   * funcname<float>(args...)  calls cublasS##funcname(args...)
//   * funcname<double>(args...) calls cublasD##funcname(args...)
// Any other Dtype has no viable overload and fails to compile. The cuBLAS
// status code is returned unchanged.
//
// Arguments are taken by value and passed straight through. cuBLAS parameters
// are handles, enums, integers and pointers, so copying is cheap. By-value
// also accepts bit-fields and in-class `static const` members without
// ODR-using them, which forwarding references would not.
//
// The expansion ends with a no-op static_assert. Each invocation can then be
// written with a trailing semicolon like an ordinary declaration, without
// leaving a stray empty declaration behind. The last macro line deliberately
// has no trailing backslash, so the following source line is not spliced into
// the macro.
#define CUBLAS_FACTORY(funcname)                          \
  template <typename Dtype, typename... Args>             \
  if_<float, Dtype> funcname(Args... args) {              \
    return cublasS##funcname(args...);                    \
  }                                                       \
                                                          \
  template <typename Dtype, typename... Args>             \
  if_<double, Dtype> funcname(Args... args) {             \
    return cublasD##funcname(args...);                    \
  }                                                       \
                                                          \
  static_assert(true, "CUBLAS_FACTORY")

// gemm_v2<Dtype>: overloads cublasSgemm_v2 (float) and cublasDgemm_v2 (double).
CUBLAS_FACTORY(gemm_v2);
// axpy_v2<Dtype>: overloads cublasSaxpy_v2 (float) and cublasDaxpy_v2 (double).
CUBLAS_FACTORY(axpy_v2);

#undef CUBLAS_FACTORY

}  // namespace detail
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment