Gist by @kevinrobinson, last active December 15, 2015 22:31.
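The snippet below is the GPU specialization of TensorFlow's LaunchMatMul functor (from matmul_op.cc, circa the gist's date): when USE_CUBLAS is set, the matrix multiply is dispatched to cuBLAS on a StreamExecutor GPU stream rather than evaluated as an Eigen contraction on the device.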
template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      Tensor* out) {
    // Logical GEMM dimensions: out is (m x n) and k is the contracted
    // dimension shared by a and b.
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    // ... options for transposing the input matrices to the format cuBLAS
    // expects (the elided code sets transpose_a/transpose_b and the
    // blas_transpose_a/blas_transpose_b flags used below) ...

    // Get a Stream for this GPUDevice.
    auto* stream = ctx->op_device_context<GPUDeviceContext>()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    // Wrap the tensors' device buffers as StreamExecutor DeviceMemory handles.
    auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data());

    // Launch the BLAS GEMM kernel on the GPU stream to perform the matrix
    // multiplication. cuBLAS assumes column-major storage, so the row-major
    // product C = A * B is issued as C^T = B^T * A^T: b is passed first and
    // m and n are swapped.
    bool blas_launch_status =
        stream->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
                             b_ptr, transpose_b ? k : n, a_ptr,
                             transpose_a ? m : k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed ..."));
    }
  }
};
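To see why b is passed before a and why m and n are swapped in the ThenBlasGemm call, here is a self-contained CPU sketch of the same column-major trick. It is not part of the gist: it uses CBLAS's cblas_sgemm purely for illustration (assuming a CBLAS implementation such as OpenBLAS is linked), with the no-transpose leading dimensions matching the call above (n for b, k for a, n for the output).

#include <cblas.h>
#include <cstdio>

int main() {
  const int m = 2, k = 3, n = 2;
  // Row-major A (m x k), B (k x n), and C (m x n).
  float a[m * k] = {1, 2, 3,
                    4, 5, 6};
  float b[k * n] = {1, 0,
                    0, 1,
                    1, 1};
  float c[m * n] = {0};

  // A row-major (m x k) buffer read as column-major is a (k x m) matrix,
  // i.e. the transpose. So a column-major GEMM over the raw buffers computes
  // C^T = B^T * A^T, which leaves C = A * B in row-major order in c. Note the
  // argument order: b first, then a, with m and n swapped.
  cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
              /*m=*/n, /*n=*/m, /*k=*/k, 1.0f,
              b, /*leading dim of b=*/n,
              a, /*leading dim of a=*/k, 0.0f,
              c, /*leading dim of c=*/n);

  // Prints 4.0 5.0 / 10.0 11.0, the row-major product A * B.
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) printf("%6.1f", c[i * n + j]);
    printf("\n");
  }
  return 0;
}

The GPU path in the gist does the same thing, just through StreamExecutor's ThenBlasGemm wrapper over the cuBLAS GEMM, with the transpose options folded into the leading-dimension selection (transpose_b ? k : n, transpose_a ? m : k).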