Last active
October 1, 2018 21:36
-
-
Save bkj/17ab7e2afc881b47b9ffb638ce2c141c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Raw CUDA implementation of an in-place, column-wise log-softmax on a
// row-major num_rows x num_cols matrix of doubles; equivalent to:
// ```python
// x = np.random.uniform(0, 1, (1000, 10))
// col_max = x.max(axis=0, keepdims=True)
// x = x - col_max
// x = np.exp(x)
// exp_col_sum = x.sum(axis=0, keepdims=True)
// x = np.log(x) - np.log(exp_col_sum)
// ```
//
// Expected launch layout (the one d_NormProb uses):
//   grid  : (num_cols, 1, 1)    -- one block per column (blockIdx.x == column)
//   block : (1, P, 1)           -- threadIdx.y == row; P must be a power of two
//                                  with P >= num_rows, so num_rows is limited by
//                                  the max threads per block
//   dynamic shared memory       -- at least num_rows * sizeof(double) bytes
// Element (r, c) is stored at d_x[r * num_cols + c].
// Threads with threadIdx.y >= num_rows are padding: they take part in the
// barriers but never read or write d_x or shared memory out of range.
__global__ void __NormProb(double * d_x, int num_cols, int num_rows)
{
    extern __shared__ double sdata[];

    const int row = threadIdx.y;
    const int col = blockIdx.x;
    const bool active = row < num_rows;
    // size_t so row * num_cols cannot overflow int for very wide matrices.
    const size_t offset = (size_t)row * num_cols + col;

    // Load this thread's element once; padding threads never touch d_x.
    const double x = active ? d_x[offset] : 0.0;

    // ---------------------------------------
    // Compute max per column (tree reduction over sdata[0 .. num_rows)).
    if (active) {
        sdata[row] = x;
    }
    __syncthreads();
    for (unsigned int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
        // The second condition skips partners that fall in the padding region.
        if ((row < stride) && (row + stride < num_rows)) {
            sdata[row] = fmax(sdata[row], sdata[row + stride]);
        }
        __syncthreads();
    }
    const double max_value = sdata[0];
    // Every thread must have read sdata[0] before thread 0 overwrites it
    // below; without this barrier other warps could read a clobbered max.
    __syncthreads();

    // ---------------------------------------
    // Subtract max value and exponentiate (max shift keeps exp from overflowing).
    if (active) {
        sdata[row] = exp(x - max_value);
    }
    __syncthreads();

    // ---------------------------------------
    // Compute sum of exp'd values (same tree reduction, with +).
    for (unsigned int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
        if ((row < stride) && (row + stride < num_rows)) {
            sdata[row] += sdata[row + stride];
        }
        __syncthreads();
    }

    // ---------------------------------------
    // log_softmax = x - max - log(sum(exp(x - max))); only real rows are written.
    if (active) {
        d_x[offset] = x - max_value - log(sdata[0]);
    }
}
// Host wrapper: applies __NormProb (in-place, column-wise log-softmax) to the
// row-major num_rows x num_cols device matrix d_x and blocks until it is done.
// Launch layout: one thread block per column, ceil_pow2(num_rows) threads per
// block along y, num_rows doubles of dynamic shared memory. num_rows therefore
// must not exceed the device's max threads per block (1024 on compute
// capability 2.0+); otherwise the launch fails.
// Errors are intentionally not consumed here: a failed launch or a kernel fault
// remains retrievable by the caller via cudaGetLastError().
void d_NormProb(int num_rows, int num_cols, double * d_x) {
    // Nothing to normalize for an empty matrix; launching a zero-sized grid
    // would instead produce cudaErrorInvalidConfiguration.
    if (num_rows <= 0 || num_cols <= 0 || !d_x) {
        return;
    }
    const dim3 grid(num_cols, 1, 1);                // one block per column
    const dim3 block(1, ceil_pow2(num_rows), 1);    // one thread per row, padded to a power of two
    const size_t shmem_bytes = (size_t)num_rows * sizeof(double);
    __NormProb<<<grid, block, shmem_bytes>>>(d_x, num_cols, num_rows);
    cudaDeviceSynchronize();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment