Skip to content

Instantly share code, notes, and snippets.

@PatWie
Last active October 24, 2018 14:26
Show Gist options
  • Save PatWie/2ebfe6850117b78d6e4c08c0fd68ae02 to your computer and use it in GitHub Desktop.
Save PatWie/2ebfe6850117b78d6e4c08c0fd68ae02 to your computer and use it in GitHub Desktop.
CUDA Index
// Patrick Wieschollek <[email protected]>, 2018
#ifndef LIB_CUDA_INDEX_H_
#define LIB_CUDA_INDEX_H_
#include <array>
namespace cuda_index {
namespace impl {
// Unsigned integer type used by the helpers below for indices, strides and
// template counters; it is an alias of size_t.
// (An earlier definition is kept for reference: typedef unsigned long ulong;)
typedef size_t ulong;
// Base case of the fold: no factors are left, so the accumulator is the
// result.
template <class Dtype>
constexpr Dtype product_impl(Dtype res) {
  return res;
}

// Recursive case: multiplies the accumulator `res` by the next factor `v0`
// and continues with the remaining factors `vs` (a left fold).
template <class Dtype, class... Dtypes>
constexpr Dtype product_impl(Dtype res, Dtype v0, Dtypes... vs) {
  return product_impl(res * v0, vs...);
}

/**
 * @brief Product of all arguments: v0 * vs[0] * ... * vs[n-1].
 *
 * @param v0 first factor; its type is also the return type
 * @param vs remaining factors (may be empty, in which case v0 is returned)
 *
 * @return the product, expressed in the type of the first factor
 */
template <class Dtype, class... Dtypes>
constexpr Dtype product(Dtype v0, Dtypes... vs) {
  // Return the fold directly. Storing it in an unsigned integer temporary
  // first (as was done before) truncated non-integral products, e.g.
  // product(0.5, 0.5) came back as 0.
  return product_impl(v0, vs...);
}
template <ulong skip, class Dtype, class... Dtypes>
struct skip_product_impl {
constexpr Dtype call(Dtype v0, Dtypes... vs) {
return skip_product_impl<skip - 1, Dtypes...>().call(vs...);
}
};
template <class Dtype, class... Dtypes>
struct skip_product_impl<0, Dtype, Dtypes...> {
constexpr Dtype call(Dtype v0, Dtypes... vs) { return product(v0, vs...); }
};
template <ulong skip, class Dtype, class... Dtypes>
constexpr Dtype skip_product(Dtype v0, Dtypes... vs) {
return skip_product_impl<skip, Dtype, Dtypes...>().call(v0, vs...);
}
/**
 * Accumulates the strided contribution of `axis` coordinates to a flat
 * offset. With N = sizeof...(Dtype), the coordinates covered are
 * multi_index[N - 1 - axis] ... multi_index[N - 2]; each is multiplied by the
 * product of the extents of all axes that follow it. The innermost
 * coordinate multi_index[N - 1] is not handled here.
 *
 * @tparam axis  number of coordinates still to be accumulated
 * @tparam Dtype types of the axis extents
 */
template <ulong axis, typename... Dtype>
struct Index_helper {
  /**
   * @param multi_index one coordinate per axis
   * @param Is          extent of each axis, in the same order
   * @return sum of coordinate * stride over the covered axes
   */
  __host__ __device__ constexpr ulong call(
      const std::array<ulong, sizeof...(Dtype)> &multi_index, Dtype... Is) {
    // Widen every extent to ulong *before* multiplying. Multiplying in the
    // extent type first (typically a 32-bit int) overflows as soon as the
    // trailing extents multiply past that type's range, even though the
    // result is stored in a ulong. Widening up front also lets the extents
    // have different integer types.
    const ulong pitch =
        skip_product<sizeof...(Dtype) - axis>(static_cast<ulong>(Is)...);
    const ulong previous =
        Index_helper<axis - 1, Dtype...>().call(multi_index, Is...);
    return multi_index[sizeof...(Dtype) - axis - 1] * pitch + previous;
  }
};

// Recursion end: no coordinates left to accumulate, contributes nothing.
template <typename... Dtype>
struct Index_helper<0, Dtype...> {
  __host__ __device__ constexpr ulong call(
      const std::array<ulong, sizeof...(Dtype)> & /* multi_index */,
      Dtype... /* Is */) {
    return 0;
  }
};
/**
 * Checks, one axis per recursion step, that every coordinate lies inside its
 * axis extent. This step tests coordinate multi_index[N - axis - 1] against
 * the first extent `v0` and delegates the remaining extents to the next step.
 *
 * @tparam N    total number of axes
 * @tparam axis number of axes still to be checked after this one
 * @tparam T    type of the extent checked in this step
 * @tparam Ts   types of the extents checked in later steps
 */
template <ulong N, ulong axis, typename T, typename... Ts>
struct Valid_helper {
  /**
   * @param multi_index one coordinate per axis
   * @param v0          extent of the axis checked in this step
   * @param vs          extents of the remaining axes
   * @return true iff this coordinate and all later ones are within bounds
   */
  __host__ __device__ constexpr bool call(
      const std::array<ulong, N> &multi_index, T v0, Ts... vs) {
    // `v0 > 0` is tested before the cast: a negative signed extent would
    // otherwise turn into a huge unsigned bound and accept every coordinate.
    return (v0 > 0) && (multi_index[N - axis - 1] < static_cast<ulong>(v0)) &&
           Valid_helper<N, axis - 1, Ts...>().call(multi_index, vs...);
  }
};

// Recursion end: checks the last coordinate against the last extent.
template <ulong N, typename T>
struct Valid_helper<N, 0, T> {
  __host__ __device__ constexpr bool call(
      const std::array<ulong, N> &multi_index, T ve) {
    // Same guard as above against negative signed extents.
    return (ve > 0) && (multi_index[N - 1] < static_cast<ulong>(ve));
  }
};
}; // namespace impl
/**
 * @brief Computes the flat offset of an element in a densely packed
 *        multi-dimensional array from its per-axis coordinates.
 * @details Writing the offset expression by hand is cumbersome and
 *          error-prone; this helper expands to the same arithmetic. The
 *          original author's stated goal is that it adds no register
 *          overhead compared to the hand-written form.
 *
 *     const int actual = cuda_index::Idx({1, 3, 5, 7}, B, H, W, C);
 *     const int expected = 1 * (H * W * C) + 3 * (W * C) + 5 * C + 7;
 *
 * The last axis is the fastest-varying one. The extent of the first axis is
 * accepted for symmetry but does not influence the result.
 *
 * @param index      one coordinate per axis, in the same order as `dimensions`
 * @param dimensions length of each axis (at least one)
 *
 * @return flat position in the multi-dimensional array
 */
template <typename... Ts>
__host__ __device__ constexpr impl::ulong Idx(
    const std::array<impl::ulong, sizeof...(Ts)> &index, Ts... dimensions) {
  // `impl::ulong` is spelled out: no `ulong` is declared in this namespace,
  // so the unqualified name only resolves where system headers happen to
  // provide a global one.
  static_assert(sizeof...(Ts) > 0, "Idx requires at least one dimension");
  // The innermost coordinate has stride 1 and is added directly; all other
  // coordinates are weighted by the helper.
  return index[sizeof...(Ts) - 1] +
         impl::Index_helper<sizeof...(Ts) - 1, Ts...>().call(index,
                                                              dimensions...);
}
/**
 * @brief Tests whether every coordinate of `index` is smaller than the
 *        extent of its axis.
 *
 *     cuda_index::IdxValid({1, 3, 5, 7}, B, H, W, C)
 *       // same as (1 < B) && (3 < H) && (5 < W) && (7 < C)
 *
 * @param index      one coordinate per axis, in the same order as `dimensions`
 * @param dimensions length of each axis
 *
 * @return true iff all coordinates are within bounds
 */
template <typename... Ts>
__host__ __device__ constexpr bool IdxValid(
    const std::array<impl::ulong, sizeof...(Ts)> &index, Ts... dimensions) {
  // `impl::ulong` is spelled out: no `ulong` is declared in this namespace,
  // so the unqualified name only resolves where system headers happen to
  // provide a global one.
  return impl::Valid_helper<sizeof...(Ts), sizeof...(Ts) - 1, Ts...>().call(
      index, dimensions...);
}
}; // namespace cuda_index
#endif // LIB_CUDA_INDEX_H_
#include <cstdio>
#include <iostream>
#include "cuda_index.h"
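// Build and run:
//   nvcc example.cu --expt-relaxed-constexpr -Xptxas="-v" && ./a.out
// --expt-relaxed-constexpr lets device code call the constexpr helpers that
// are not marked __device__; -Xptxas="-v" makes ptxas report the resource
// (register) usage of each kernel so the variants below can be compared.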
using cuda_index::Idx;
using cuda_index::IdxValid;
// Reference kernel: prints the flat offset of element (1, 3, 5, 7) computed
// with a hand-written expression. B is accepted but not needed for the
// offset. Launched with a single thread in this example.
__global__ void index_expected(int B, int H, int W, int C) {
  const int batch_term = 1 * (H * W * C);
  const int row_term = 3 * W * C;
  const int col_term = 5 * C;
  const int idx = batch_term + row_term + col_term + 7;
  printf("value is %i\n", idx);
}
// Prints the flat offset of element (1, 3, 5, 7) as computed by Idx, for
// comparison with the hand-written reference kernel. Launched with a single
// thread in this example.
__global__ void index_actual(int B, int H, int W, int C) {
  const int flat_offset = Idx({1, 3, 5, 7}, B, H, W, C);
  printf("value is %i\n", flat_offset);
}
// Reference kernel: performs two hand-written bounds checks and prints
// "valid" for each one that passes. The first tests (1, 3, 5, 7); the second
// uses H + 1 as the coordinate on the H axis. Launched with a single thread
// in this example.
__global__ void valid_expected(int B, int H, int W, int C) {
  const bool first_in_bounds = (1 < B) && (3 < H) && (5 < W) && (7 < C);
  if (first_in_bounds) {
    printf("valid \n");
  }

  const bool second_in_bounds = (1 < B) && (H + 1 < H) && (5 < W) && (7 < C);
  if (second_in_bounds) {
    printf("valid \n");
  }
}
// Performs the same two bounds checks as the hand-written reference kernel,
// but through IdxValid, and prints "valid" for each one that passes. The
// first tests (1, 3, 5, 7); the second uses H + 1 as the coordinate on the H
// axis. Launched with a single thread in this example.
__global__ void valid_actual(int B, int H, int W, int C) {
  if (IdxValid({1, 3, 5, 7}, B, H, W, C)) {
    printf("valid \n");
  }
  // H + 1 is a run-time int; converting it implicitly inside the braces is a
  // narrowing conversion, which list-initialization does not permit. The
  // cast makes the conversion to the index type explicit.
  if (IdxValid({1, static_cast<size_t>(H + 1), 5, 7}, B, H, W, C)) {
    printf("valid \n");
  }
}
// Launches the four demo kernels with one thread each, waits for them to
// finish, and reports any CUDA error on stderr.
// Returns 0 on success and 1 if a launch or the final synchronization failed.
int main(int argc, char const *argv[]) {
  (void)argc;
  (void)argv;

  const int B = 4;
  const int H = 17;
  const int W = 32;
  const int C = 128;
  const dim3 grid(1);
  const dim3 block(1);

  // Kernel launches do not return a status; failures have to be fetched
  // explicitly, otherwise the program prints nothing and still exits with 0.
  const auto succeeded = [](cudaError_t status, const char *what) -> bool {
    if (status == cudaSuccess) {
      return true;
    }
    fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(status));
    return false;
  };

  index_expected<<<grid, block>>>(B, H, W, C);
  if (!succeeded(cudaGetLastError(), "index_expected launch")) return 1;

  index_actual<<<grid, block>>>(B, H, W, C);
  if (!succeeded(cudaGetLastError(), "index_actual launch")) return 1;

  valid_expected<<<grid, block>>>(B, H, W, C);
  if (!succeeded(cudaGetLastError(), "valid_expected launch")) return 1;

  valid_actual<<<grid, block>>>(B, H, W, C);
  if (!succeeded(cudaGetLastError(), "valid_actual launch")) return 1;

  // Errors raised while the kernels execute surface at this synchronization.
  if (!succeeded(cudaDeviceSynchronize(), "cudaDeviceSynchronize")) return 1;
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment