struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      Tensor* out) {
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    // ... options for transposing the input matrices to the format cuBLAS expects ...

    // Get a Stream for this GPUDevice.
    auto* stream = ctx->op_device_context<GPUDeviceContext>()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data());

    // Launch the BLAS gemm kernel on the GPU stream, which will perform the
    // matrix multiplication.
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
                           b_ptr, transpose_b ? k : n, a_ptr,
                           transpose_a ? m : k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed ..."));
    }
  }
};
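
Note the operand order in the ThenBlasGemm call above: b is passed before a, and n before m. cuBLAS assumes column-major storage, while TensorFlow tensors are row-major; a row-major matrix reinterpreted as column-major is its transpose, so the call effectively computes C' = B' x A' in column-major terms, which leaves C = A x B laid out in row-major memory. Below is a minimal CPU-only sketch of that identity. It is not TensorFlow or cuBLAS code; the matmul_col_major helper is a hypothetical stand-in for a column-major GEMM, used only to show why the operands and dimensions get swapped.

// Sketch (assumed, not from the source): a naive column-major GEMM handed
// row-major buffers with swapped operands produces the row-major product.
#include <cstdio>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n), leading dims = rows.
void matmul_col_major(const std::vector<float>& A, const std::vector<float>& B,
                      std::vector<float>& C, int m, int n, int k) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += A[p * m + i] * B[j * k + p];
      C[j * m + i] = acc;
    }
}

int main() {
  // Row-major inputs: a is 2x3 (m=2, k=3), b is 3x2 (k=3, n=2).
  const int m = 2, k = 3, n = 2;
  std::vector<float> a = {1, 2, 3,
                          4, 5, 6};    // row-major 2x3
  std::vector<float> b = {7,  8,
                          9,  10,
                          11, 12};     // row-major 3x2
  std::vector<float> c(m * n, 0.0f);   // will hold the row-major 2x2 result

  // Viewed as column-major, the row-major buffers are b' and a', so computing
  // c' (n x m) = b' (n x k) * a' (k x m) stores a * b in row-major order.
  // Note the swapped operand order and swapped m/n, mirroring ThenBlasGemm above.
  matmul_col_major(b, a, c, n, m, k);

  // Expected row-major result: [[58, 64], [139, 154]].
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) printf("%6.1f ", c[i * n + j]);
    printf("\n");
  }
  return 0;
}

Compiled and run, this prints 58 64 / 139 154, the row-major product a * b, even though the helper itself multiplies in column-major order with b passed first.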