@killeent
Created June 27, 2017 18:33
// Super Dumb Kernel
__device__ __forceinline__ long calculateOffset(
    long index,               // linear index to calculate the offset for
    int ndim,                 // number of dimensions in the Tensor
    long sizes[8],            // sizes for the Tensor dims (either from the Tensor, or the size of the adv indexer at that dim)
    long strides[8],          // strides for the Tensor
    bool adv[8],              // which dims are advanced indexers
    long *advIndexTensors[8]) // advanced indexing Tensors (device pointers to index data)
{
  long offset = 0;
  // Walk the dims from innermost to outermost, peeling off one mixed-radix
  // "digit" of the linear index per dim
  for (int dim = ndim - 1; dim >= 0; --dim) {
    long sizeAtDim, strideAtDim, indexAtDim, nextIndex;

    strideAtDim = strides[dim];
    sizeAtDim = sizes[dim];

    if (adv[dim]) {
      // advanced indexer: look the index up in the indexing Tensor
      indexAtDim = advIndexTensors[dim][index % sizeAtDim];
      if (dim > 0 && adv[dim - 1]) {
        // consecutive advanced indexers iterate in lockstep, so the next dim
        // reuses the same digit of the linear index
        nextIndex = index;
      } else {
        nextIndex = index / sizeAtDim;
      }
    } else {
      indexAtDim = index % sizeAtDim;
      nextIndex = index / sizeAtDim;
    }

    offset += indexAtDim * strideAtDim;
    index = nextIndex;
  }
  return offset;
}
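
// Worked example (an illustration, not part of the original gist): for a
// contiguous 5x10 tensor x indexed as x[:, idx], where idx is a 3-element
// index tensor, the arguments would plausibly be ndim = 2, sizes = {5, 3},
// strides = {10, 1}, adv = {false, true}, advIndexTensors = {NULL, idx_data}.
// calculateOffset then decomposes a flat result index i (0 <= i < 15) as:
//
//   dim 1 (advanced): indexAtDim = idx_data[i % 3], nextIndex = i / 3
//   dim 0 (plain):    indexAtDim = (i / 3) % 5
//
//   offset = idx_data[i % 3] * 1 + ((i / 3) % 5) * 10
//
// i.e. output element i = r * 3 + c maps to the offset of x[r, idx_data[c]].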
__global__ void calculateLinearIndices(
    long *output,             // output Tensor for the calculated indices
    int elements,             // number of elements in the output <-> indices to calculate
    int ndim,                 // number of dimensions in the Tensor
    long sizes[8],            // sizes for the Tensor dims (either from the Tensor, or the size of the adv indexer at that dim)
    long strides[8],          // strides for the Tensor
    bool adv[8],              // which dims are advanced indexers
    long *advIndexTensors[8]) // advanced indexing Tensors (device pointers to index data)
{
  // grid-stride loop: each thread computes the offset for one or more
  // output elements
  for (long i = blockIdx.x * blockDim.x + threadIdx.x;
       i < elements;
       i += blockDim.x * gridDim.x) {
    output[i] = calculateOffset(i, ndim, sizes, strides, adv, advIndexTensors);
  }
}
// Basic launch stub, assuming we already have all of the parameters needed
{
  // initialize the output buffer (one offset per element)
  long *output;
  cudaMalloc(...);

  dim3 block = getApplyBlock();
  dim3 grid;
  getApplyGrid(state, elements, grid);

  calculateLinearIndices<<<grid, block, 0, curr_stream>>>(
      output,
      ...
  );
  THCudaCheck(cudaGetLastError());
}
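
// A minimal host-side sketch (an assumption, not part of the original gist) of
// how the array arguments might be staged for the x[:, idx] example above.
// Since the kernel's array parameters decay to pointers, the sketch copies
// them into device memory first. d_idx (device pointer to idx's data), grid,
// block, output, and curr_stream are assumed to exist as in the launch stub.
{
  long h_sizes[8]            = {5, 3};        // dim 1 takes the size of the adv indexer
  long h_strides[8]          = {10, 1};       // strides of the contiguous 5x10 tensor x
  bool h_adv[8]              = {false, true}; // only dim 1 is an advanced indexer
  long *h_advIndexTensors[8] = {NULL, d_idx}; // device pointer to idx's data

  long *d_sizes, *d_strides;
  bool *d_adv;
  long **d_advIndexTensors;
  cudaMalloc(&d_sizes, 8 * sizeof(long));
  cudaMalloc(&d_strides, 8 * sizeof(long));
  cudaMalloc(&d_adv, 8 * sizeof(bool));
  cudaMalloc(&d_advIndexTensors, 8 * sizeof(long *));
  cudaMemcpy(d_sizes, h_sizes, 8 * sizeof(long), cudaMemcpyHostToDevice);
  cudaMemcpy(d_strides, h_strides, 8 * sizeof(long), cudaMemcpyHostToDevice);
  cudaMemcpy(d_adv, h_adv, 8 * sizeof(bool), cudaMemcpyHostToDevice);
  cudaMemcpy(d_advIndexTensors, h_advIndexTensors, 8 * sizeof(long *), cudaMemcpyHostToDevice);

  int elements = 5 * 3;  // one offset per element of the indexing result
  calculateLinearIndices<<<grid, block, 0, curr_stream>>>(
      output, elements, 2, d_sizes, d_strides, d_adv, d_advIndexTensors);
  THCudaCheck(cudaGetLastError());
}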