Last active
October 24, 2018 14:26
-
-
Save PatWie/2ebfe6850117b78d6e4c08c0fd68ae02 to your computer and use it in GitHub Desktop.
CUDA Index
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Patrick Wieschollek <[email protected]>, 2018 | |
#ifndef LIB_CUDA_INDEX_H_ | |
#define LIB_CUDA_INDEX_H_ | |
#include <array> | |
namespace cuda_index { | |
namespace impl { | |
// Unsigned integer type used for every index, extent and pitch in this header.
// Alternative definition kept for reference:
//   typedef unsigned long ulong;
using ulong = size_t;
// Recursion anchor: no factors are left, so the accumulated value is the result.
template <class Dtype>
constexpr Dtype product_impl(Dtype acc) {
  return acc;
}

// Recursive step: folds the next factor into the accumulator and continues
// with the remaining factors.
template <class Dtype, class... Dtypes>
constexpr Dtype product_impl(Dtype acc, Dtype head, Dtypes... tail) {
  return product_impl(acc * head, tail...);
}
/**
 * @brief Multiplies all arguments together.
 *
 * @param v0 first factor; its type is also the return type
 * @param vs remaining factors (may be empty)
 *
 * @return the product of v0 and every value in vs, converted to Dtype
 */
template <class Dtype, class... Dtypes>
constexpr Dtype product(Dtype v0, Dtypes... vs) {
  // Return the folded value directly. Storing it in a `ulong` temporary first
  // (as this function used to do) truncated non-integral results, e.g.
  // product(0.5, 0.5) came back as 0.
  return product_impl(v0, vs...);
}
// Drops the leading `skip` arguments one at a time, then multiplies what is
// left. Each recursion level consumes exactly one argument.
template <ulong skip, class Dtype, class... Dtypes>
struct skip_product_impl {
  // Discards the first argument and forwards the rest with the counter
  // decremented.
  constexpr Dtype call(Dtype /*dropped*/, Dtypes... rest) {
    using next_step = skip_product_impl<skip - 1, Dtypes...>;
    return next_step().call(rest...);
  }
};

// Counter reached zero: every remaining argument takes part in the product.
template <class Dtype, class... Dtypes>
struct skip_product_impl<0, Dtype, Dtypes...> {
  constexpr Dtype call(Dtype first, Dtypes... rest) {
    return product(first, rest...);
  }
};
// Product of the arguments that remain after dropping the first `skip` of
// them; the result is returned in the type of the first argument.
template <ulong skip, class Dtype, class... Dtypes>
constexpr Dtype skip_product(Dtype v0, Dtypes... vs) {
  using dropper = skip_product_impl<skip, Dtype, Dtypes...>;
  return dropper{}.call(v0, vs...);
}
/**
 * @brief Sums the offset contributed by the leading positions of a
 *        multi-dimensional index.
 *
 * With N = sizeof...(Dtype), the instantiation for `axis` handles position
 * N - axis - 1 of `multi_index`: it multiplies that entry by the product of
 * all extents that follow this position (the pitch) and adds the result of
 * the instantiation for `axis - 1`. The recursion ends at `axis == 0`, which
 * contributes nothing, so position N - 1 is never handled here.
 *
 * @param multi_index index for each axis
 * @param Is          length of each axis
 *
 * @return accumulated offset of positions N - axis - 1 up to N - 2
 */
template <ulong axis, typename... Dtype>
struct Index_helper {
  __host__ __device__ constexpr ulong call(
      const std::array<ulong, sizeof...(Dtype)> &multi_index, Dtype... Is) {
    // Widen every extent to ulong *before* multiplying. Multiplying in the
    // extents' own type (typically int) overflows once the pitch no longer
    // fits that type, even though the final offset is a ulong.
    const ulong pitch =
        skip_product<sizeof...(Dtype) - axis>(static_cast<ulong>(Is)...);
    const ulong previous =
        Index_helper<axis - 1, Dtype...>().call(multi_index, Is...);
    return multi_index[sizeof...(Dtype) - axis - 1] * pitch + previous;
  }
};

// Recursion anchor: no position left to account for, so the offset is zero.
template <typename... Dtype>
struct Index_helper<0, Dtype...> {
  __host__ __device__ constexpr ulong call(
      const std::array<ulong, sizeof...(Dtype)> & /*multi_index*/,
      Dtype... /*Is*/) {
    return 0;
  }
};
/**
 * @brief Checks that every entry of a multi-dimensional index is below the
 *        extent of its axis.
 *
 * The instantiation for `axis` compares position N - axis - 1 of
 * `multi_index` against the first extent it receives and delegates the
 * remaining extents to the instantiation for `axis - 1`.
 *
 * @param multi_index index for each axis
 * @param v0          extent belonging to position N - axis - 1
 * @param vs          extents of the positions that follow
 *
 * @return true when this entry and all following entries are in range
 */
template <ulong N, ulong axis, typename T, typename... Ts>
struct Valid_helper {
  // Returns bool (previously ulong) so that it agrees with the specialization
  // below and with IdxValid.
  __host__ __device__ constexpr bool call(
      const std::array<ulong, N> &multi_index, T v0, Ts... vs) {
    const ulong value = multi_index[N - axis - 1];
    // The extent is compared as an unsigned value, like the index entries.
    const ulong bound = static_cast<ulong>(v0);
    return (value < bound) &&
           Valid_helper<N, axis - 1, Ts...>().call(multi_index, vs...);
  }
};

// Recursion anchor: only the last position and its extent remain.
template <ulong N, typename T>
struct Valid_helper<N, 0, T> {
  __host__ __device__ constexpr bool call(
      const std::array<ulong, N> &multi_index, T ve) {
    return multi_index[N - 1] < static_cast<ulong>(ve);
  }
};
}; // namespace impl | |
/**
 * @brief computes the index with given strides (without overhead in register usage)
 * @details accessing a particular entry in a multi-dimensional array can be cumbersome
 *
 *
 *     const int actual = cuda_index::Idx({1, 3, 5, 7}, B, H, W, C);
 *     const int expected = 1 * (H * W * C) + 3 * (W * C) + 5 * C + 7;
 *
 * The last axis has stride 1 and the extent of the first axis does not enter
 * the computation. No bounds check is performed here; see IdxValid.
 *
 * @param index index for each axis
 * @param dimensions length of each axis
 *
 * @return correct position in multi-dimensional array
 */
template <typename... Ts>
__host__ __device__ constexpr impl::ulong Idx(
    const std::array<impl::ulong, sizeof...(Ts)> &index, Ts... dimensions) {
  // `ulong` is declared inside namespace impl, so it has to be qualified here;
  // the unqualified name only resolves on platforms whose C library happens to
  // provide a global `ulong`.
  return index[sizeof...(Ts) - 1] +
         impl::Index_helper<sizeof...(Ts) - 1, Ts...>().call(index,
                                                             dimensions...);
}
/**
 * @brief checks whether a multi-dimensional index lies inside the given extents
 *
 *     cuda_index::IdxValid({1, 3, 5, 7}, B, H, W, C)
 *     // same as (1 < B) && (3 < H) && (5 < W) && (7 < C)
 *
 * @param index index for each axis
 * @param dimensions length of each axis
 *
 * @return true when every entry of index is smaller than the length of its axis
 */
template <typename... Ts>
__host__ __device__ constexpr bool IdxValid(
    const std::array<impl::ulong, sizeof...(Ts)> &index, Ts... dimensions) {
  // `ulong` is declared inside namespace impl, so it has to be qualified here;
  // the unqualified name only resolves on platforms whose C library happens to
  // provide a global `ulong`.
  return impl::Valid_helper<sizeof...(Ts), sizeof...(Ts) - 1, Ts...>().call(
      index, dimensions...);
}
}; // namespace cuda_index | |
#endif // LIB_CUDA_INDEX_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdio>
#include <iostream>

#include "cuda_index.h"
// nvcc example.cu --expt-relaxed-constexpr -Xptxas="-v" && ./a.out | |
using cuda_index::Idx; | |
using cuda_index::IdxValid; | |
// Reference kernel: prints the flat offset of entry (1, 3, 5, 7) computed with
// hand-written stride arithmetic. B is accepted but does not enter the result.
__global__ void index_expected(int B, int H, int W, int C) {
  const int term0 = 1 * (H * W * C);
  const int term1 = 3 * W * C;
  const int term2 = 5 * C;
  const int idx = term0 + term1 + term2 + 7;
  printf("value is %i\n", idx);
}
// Prints the flat offset of entry (1, 3, 5, 7) as computed by cuda_index::Idx,
// for comparison with index_expected.
__global__ void index_actual(int B, int H, int W, int C) {
  const auto offset = Idx({1, 3, 5, 7}, B, H, W, C);
  // The offset is an unsigned value; it is printed as an int.
  const int idx = static_cast<int>(offset);
  printf("value is %i\n", idx);
}
// Reference kernel: performs the bounds checks by hand, once for entry
// (1, 3, 5, 7) and once for entry (1, H + 1, 5, 7), printing a line for each
// check that passes.
__global__ void valid_expected(int B, int H, int W, int C) {
  const bool first_entry_ok = (1 < B) && (3 < H) && (5 < W) && (7 < C);
  const bool second_entry_ok = (1 < B) && (H + 1 < H) && (5 < W) && (7 < C);

  if (first_entry_ok) {
    printf("valid \n");
  }
  if (second_entry_ok) {
    printf("valid \n");
  }
}
// Runs the same two bounds checks as valid_expected through
// cuda_index::IdxValid, printing a line for each check that passes.
__global__ void valid_actual(int B, int H, int W, int C) {
  if (IdxValid({1, 3, 5, 7}, B, H, W, C)) {
    printf("valid \n");
  }
  // H + 1 is a run-time int. Converting it to the unsigned element type inside
  // a braced initializer is a narrowing conversion, so it is cast explicitly.
  if (IdxValid({1, static_cast<size_t>(H + 1), 5, 7}, B, H, W, C)) {
    printf("valid \n");
  }
}
// Reports a failed CUDA runtime call on stderr.
// Returns true when `status` is cudaSuccess, false otherwise.
static bool cudaOk(cudaError_t status, const char *what) {
  if (status == cudaSuccess) {
    return true;
  }
  std::fprintf(stderr, "CUDA error during %s: %s\n", what,
               cudaGetErrorString(status));
  return false;
}

// Launches the four demo kernels with a single thread each and waits for them
// to finish. Returns 0 on success and 1 if a launch or the execution failed.
int main(int argc, char const *argv[]) {
  (void)argc;  // command-line arguments are not used
  (void)argv;

  const int B = 4;
  const int H = 17;
  const int W = 32;
  const int C = 128;

  const dim3 grid(1);
  const dim3 block(1);

  // A kernel launch returns nothing; configuration errors only become visible
  // through cudaGetLastError() right after the launch.
  index_expected<<<grid, block>>>(B, H, W, C);
  if (!cudaOk(cudaGetLastError(), "index_expected launch")) return 1;

  index_actual<<<grid, block>>>(B, H, W, C);
  if (!cudaOk(cudaGetLastError(), "index_actual launch")) return 1;

  valid_expected<<<grid, block>>>(B, H, W, C);
  if (!cudaOk(cudaGetLastError(), "valid_expected launch")) return 1;

  valid_actual<<<grid, block>>>(B, H, W, C);
  if (!cudaOk(cudaGetLastError(), "valid_actual launch")) return 1;

  // Errors raised while the kernels run are reported by the synchronizing
  // call; it also flushes the device-side printf output.
  if (!cudaOk(cudaDeviceSynchronize(), "kernel execution")) return 1;

  return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment