@killeent · Created March 15, 2017
#ifndef THC_TENSOR_MODE_CUH
#define THC_TENSOR_MODE_CUH

#include "THCNumerics.cuh"
#include "THCSortUtils.cuh"
struct ThrustHalfLess
{
  __host__ __device__ inline bool operator()(const half& lhs, const half& rhs) {
    return THCNumerics<half>::lt(lhs, rhs);
  }
};

struct ThrustHalfNotEqualTo
{
  __host__ __device__ inline bool operator()(const half& lhs, const half& rhs) {
    return THCNumerics<half>::ne(lhs, rhs);
  }
};

struct ThrustHalfEqualTo
{
  __host__ __device__ inline bool operator()(const half& lhs, const half& rhs) {
    return THCNumerics<half>::eq(lhs, rhs);
  }
};

struct ThrustHalfEqualToPredicate
{
  ThrustHalfEqualToPredicate(half val): val_(val) {}
  __host__ __device__ inline bool operator()(half x) {
    return THCNumerics<half>::eq(val_, x);
  }
  half val_;
};
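
// These functors give Thrust algorithms a way to compare half values, which lack
// built-in operators. A hypothetical host-side use (illustrative only; keys, n,
// and modeValue are assumptions, and <thrust/sort.h>, <thrust/count.h>, and
// <thrust/execution_policy.h> would be required):
//
//   thrust::sort(thrust::device, keys, keys + n, ThrustHalfLess());
//   long matches = thrust::count_if(thrust::device, keys, keys + n,
//                                   ThrustHalfEqualToPredicate(modeValue));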
template <typename T>
struct BinaryAddOp {
  __host__ __device__ inline T operator()(const T a, const T b) {
    return THCNumerics<T>::add(a, b);
  }
};

template <>
struct BinaryAddOp<unsigned int> {
  __host__ __device__ inline unsigned int operator()(const unsigned int a, const unsigned int b) {
    return a + b;
  }
};
template <typename T, class BinaryOp, int Power2ScanSize>
__device__ void segmentedInclusivePrefixScan(T *smem, bool *bmem, BinaryOp binop) {
  // Reduce step ("upsweep")
  #pragma unroll
  for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if (index < Power2ScanSize) {
      // A set flag at index means a segment starts there: keep the value as-is
      // rather than accumulating across the segment boundary
      smem[index] = bmem[index] ? smem[index] : binop(smem[index], smem[index - stride]);
      bmem[index] = bmem[index] | bmem[index - stride];
    }
    __syncthreads();
  }

  // Post-reduce step ("downsweep")
  #pragma unroll
  for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if ((index + stride) < Power2ScanSize) {
      smem[index + stride] = bmem[index + stride] ? smem[index + stride] : binop(smem[index + stride], smem[index]);
      bmem[index + stride] = bmem[index + stride] | bmem[index];
    }
    __syncthreads();
  }
}
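
// Worked example (an illustration, not from the original source): the scan sums
// values within segments whose starts are flagged in bmem. With
// Power2ScanSize = 8 and BinaryAddOp<unsigned int>:
//
//   smem (in)  = [1, 1, 1, 1, 1, 1, 1, 1]
//   bmem (in)  = [1, 0, 0, 1, 0, 1, 0, 0]
//   smem (out) = [1, 2, 3, 1, 2, 1, 2, 3]
//
// Segments begin at indices 0, 3, and 5 and are scanned independently. Note that
// the scan also propagates flags through bmem, so callers that need the original
// flags must re-initialize them afterwards (as the mode kernel below does).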
// Debugging helpers: no-ops for a general T, with specializations that print
// int and float shared memory buffers and mode values.
template <typename T>
__device__ inline void printSmem(T *smem, int sliceSize) {}

template <>
__device__ inline void printSmem<int>(int *smem, int sliceSize) {
  for (int i = 0; i < sliceSize; ++i) {
    printf(" %d\n", smem[i]);
  }
}

template <>
__device__ inline void printSmem<float>(float *smem, int sliceSize) {
  for (int i = 0; i < sliceSize; ++i) {
    printf(" %f\n", smem[i]);
  }
}

template <typename T>
__device__ inline void printVal(T val) {}

template <>
__device__ inline void printVal<int>(int val) {
  printf("mode: %d\n", val);
}

template <>
__device__ inline void printVal<float>(float val) {
  printf("mode: %f\n", val);
}
// The mode kernel has the following characteristics: it uses internal shared memory
// buffers of size Power2Size, which must be greater than the number of elements in a
// slice. Additionally, there is one block for every slice to calculate the mode for,
// and in each block there is one thread for every two elements.
//
// Both sorted and positions are assumed to be contiguous Tensors with the mode dimension
// as the innermost dim, such that we can get the particular slice for a Tensor via its
// linear block dimension * the slice size.
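//
// A hypothetical launch for slices of up to 128 elements (host-side code is not
// part of this header; numSlices, valuesInfo, indicesInfo, and sliceSize are
// illustrative assumptions):
//
//   dim3 grid(numSlices);  // one block per slice
//   dim3 block(128 / 2);   // one thread for every two elements
//   computeMode<float, 128><<<grid, block>>>(
//       input, valuesInfo, indicesInfo, sliceSize);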
template <typename T, unsigned int Power2Size>
__global__ void computeMode(
    T *input,
    TensorInfo<T, unsigned int> values,
    TensorInfo<long, unsigned int> indices,
    long sliceSize)
{
  int tidx = threadIdx.x;
  int stidx = blockDim.x + threadIdx.x; // second index this thread is responsible for

  // First, we need to calculate the offset into the sorted Tensor that represents
  // the start of the slice for this block to calculate the mode for. This offset
  // is a combination of the gridIndices, and the number of elements in the slice.
  unsigned int blockId = getLinearBlockId<unsigned int>();
  unsigned int linearOffset = blockId * sliceSize;

  // The smem buffer is used to store the elements from the slice
  __shared__ T smem[Power2Size];

  // Each thread loads up to two elements from the Tensor into shared memory
  if (tidx < sliceSize) {
    smem[tidx] = input[linearOffset + tidx];
  }
  if (stidx < sliceSize) {
    smem[stidx] = input[linearOffset + stidx];
  }
  // The bmem buffer is used in multiple components of the kernel. First, it stores
  // bmem[i] = (i < sliceSize) to mark the valid components in the smem buffer
  // for sorting
  __shared__ bool bmem[Power2Size];
  bmem[tidx] = tidx < sliceSize;
  bmem[stidx] = stidx < sliceSize;
  __syncthreads(); // barrier for smem, bmem initialization

  // First, sort the input slice in ascending order. smem contains the input
  // elements, and bmem marks the valid indices
  bitonicSortKeys<LTComp<T>, T, unsigned int, Power2Size>(smem, bmem, LTComp<T>());
  __syncthreads(); // make no assumptions that the sort syncs at end
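
  // For example (an illustrative trace, not from the original source): with
  // sliceSize = 5 and Power2Size = 8, a loaded slice [3, 1, 2, 1, 3] sorts so
  // that the first sliceSize entries of smem read [1, 1, 2, 3, 3]; the validity
  // flags in bmem push the three uninitialized trailing slots past the valid
  // elements.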
  // The next step of our algorithm is performing a block-wide comparison of
  // neighboring elements. In particular, given a sorted input slice A, we
  // produce an output slice B, such that B[i] = 1 if A[i-1] != A[i], otherwise 0.
  //
  // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8]
  //                 B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
  //
  // We re-use the bmem buffer for this computation. In particular, we can think of
  // B[i] = true indicating the start of a sequence of equal values in the sorted
  // list.
  if (tidx == 0) {
    bmem[tidx] = true; // element 0 always starts a segment
  }

  // Compares elements (0, 1), (2, 3), ... and sets flags 1, 3, ...
  bmem[tidx * 2 + 1] = THCNumerics<T>::ne(smem[tidx * 2], smem[tidx * 2 + 1]);

  // Compares elements (1, 2), (3, 4), ... and sets flags 2, 4, ...
  if (((tidx + 1) * 2) < Power2Size) {
    bmem[(tidx + 1) * 2] = THCNumerics<T>::ne(smem[((tidx + 1) * 2) - 1], smem[(tidx + 1) * 2]);
  }
  __syncthreads(); // barrier for bmem initialization
  // Next, we initialize another shared memory buffer cmem which will be used in the
  // Segmented Prefix Sum
  __shared__ unsigned int cmem[Power2Size];

  // We set cmem to be the negation of bmem. In particular, we can think of cmem[i] = true
  // iff A[i-1] == A[i] in our original sorted slice.
  cmem[tidx] = !bmem[tidx];
  cmem[stidx] = !bmem[stidx];
  __syncthreads(); // barrier for cmem initialization

  // Next, we perform a segmented prefix sum on the neighboring elements, where
  // the presence of a one indicates the start of a segment. In this case bmem acts
  // as the segment start flags, and cmem is the buffer to be summed:
  //
  // Input  (cmem) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
  // Flag   (bmem) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
  // Output (cmem) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0]
  //
  // Afterwards, the cmem buffer contains the lengths of the segments (minus 1), i.e.
  // the counts of each element in the original input.
  segmentedInclusivePrefixScan<unsigned int, BinaryAddOp<unsigned int>, Power2Size>(cmem, bmem, BinaryAddOp<unsigned int>());
  // Our last shared memory buffer is used to track indices
  __shared__ unsigned int imem[Power2Size];

  // Initialize the indices buffer such that imem[i] = i
  imem[tidx] = tidx;
  imem[stidx] = stidx;
  __syncthreads(); // barrier for both the scan and the imem initialization

  // At this point, we need to find the maximum element in the cmem buffer.
  // This element will represent the count (-1) of the mode. Because of the
  // way we have set up the problem, the index where this mode occurs will
  // also be the location of the mode value in the sorted array, e.g.
  //
  // smem = [0, 0, 1, 1, 1, 2]
  // cmem = [0, 1, 0, 1, 2, 0]
  //                     ^
  //                     maximum value, also aligned with mode = 1
  //
  // We perform a block-wide max-reduction of the cmem buffer, and bring imem
  // along with it.
  //
  // Loop 1 (offset = Power2Size / 2 = 4):
  //
  // (0, 4) --> cmem[4] = 2 > cmem[0] = 0, so update cmem[0], imem[0]
  // (1, 5) --> cmem[5] = 0 <= cmem[1] = 1, do nothing
  //
  // Now:    0  1  2  3  4  5
  // cmem = [2, 1, 0, 1, 2, 0]
  // imem = [4, 1, 2, 3, 4, 5]
  //
  // Loop 2 (offset = 2):
  //
  // (0, 2) --> cmem[2] = 0 <= cmem[0] = 2, do nothing
  // (1, 3) --> cmem[3] = 1 <= cmem[1] = 1, do nothing
  //
  // Both buffers are unchanged:
  //
  // Now:    0  1  2  3  4  5
  // cmem = [2, 1, 0, 1, 2, 0]
  // imem = [4, 1, 2, 3, 4, 5]
  //
  // Loop 3 (offset = 1):
  //
  // (0, 1) --> cmem[1] = 1 <= cmem[0] = 2, do nothing
  //
  // So at the end, cmem[0] holds the maximum count = 2, and imem[0] the
  // corresponding index = 4
  #pragma unroll
  for (unsigned int offset = Power2Size / 2; offset > 0; offset >>= 1) {
    if (tidx < offset && tidx + offset < sliceSize) {
      // Note that we could do >= as well. We use >, so that we pick the
      // earliest maximum value in the sequence in case of ties. This will
      // result in picking the smallest value for the mode, i.e. if both
      // 3 and 4 occur the same number of times in the input, and their count
      // is the mode, then we return 3. This mimics the behavior of CPU-Torch.
      if (cmem[tidx + offset] > cmem[tidx]) {
        cmem[tidx] = cmem[tidx + offset];
        imem[tidx] = imem[tidx + offset];
      }
    }
    __syncthreads();
  }
  // Store the mode in shared memory for use in finding the mode in the input slice
  __shared__ T mode;

  // Given the above constraints, the mode is the value at the maximum index in the
  // segmented scan
  if (tidx == 0) {
    mode = smem[imem[0]];
  }
  __syncthreads(); // broadcast mode
  // Finally, we need to find an index of the mode in the input Tensor. The API does
  // not constrain which index we pick, so it can be any of the indices that contain
  // the mode. We will do a reduction to find the index. First, we mark indices that
  // are equal to the mode, i.e. bmem[i] = true if input[i] == mode
  if (tidx < sliceSize) {
    bmem[tidx] = THCNumerics<T>::eq(input[linearOffset + tidx], mode);
    imem[tidx] = tidx;
  }
  if (stidx < sliceSize) {
    bmem[stidx] = THCNumerics<T>::eq(input[linearOffset + stidx], mode);
    imem[stidx] = stidx;
  }
  __syncthreads(); // barrier for initialization of bmem, imem
  // Then we perform a similar reduction to the one above, except this time we update
  // the element if the element at the base position is not equal to the mode and
  // the element at the offset position is. At the end, imem[0] will contain an index
  // whose element is the mode.
  for (unsigned int offset = Power2Size / 2; offset > 0; offset >>= 1) {
    if (tidx < offset && tidx + offset < sliceSize) {
      // Unconditionally take the offset index if it holds the mode; this also
      // marks the base position as a match for later iterations
      if (bmem[tidx + offset]) {
        imem[tidx] = imem[tidx + offset];
        bmem[tidx] = true;
      }
    }
    __syncthreads();
  }
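
  // Illustrative trace (an exposition aid, not from the original source):
  // suppose Power2Size = 8, sliceSize = 4, mode = 2, and the input slice is
  // [2, 1, 2, 0]. Then initially:
  //
  //   bmem = [1, 0, 1, 0]   imem = [0, 1, 2, 3]
  //
  //   offset = 4: tidx + offset >= sliceSize for every thread, so no updates
  //   offset = 2: bmem[2] is true  --> imem[0] = 2, bmem[0] = true
  //               bmem[3] is false --> no change
  //   offset = 1: bmem[1] is false --> no change
  //
  // imem[0] = 2 is an index of the mode (index 0 would have been equally valid).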
  // Finally, we have the mode, and an index where it occurs. We use a single thread
  // to place this in the appropriate output position
  if (tidx == 0) {
    long index = TH_INDEX_BASE + imem[0];
    unsigned int outputOffset = IndexToOffset<T, unsigned int, -1>::get(blockId, values);
    values.data[outputOffset] = mode;
    indices.data[outputOffset] = index;
  }
}

#endif // THC_TENSOR_MODE_CUH