CUDAnative softmax, ported from Marian-NMT
# ported from Marian-NMT
# https://github.com/marian-nmt/marian-dev/blob/8fbfa656/src/tensors/gpu/tensor_operators.cu#L206-L320
# licensed under MIT
using CUDAdrv
using CUDAnative
using BenchmarkTools

const MAX_THREADS = 256     # seems to work best (1024 max)
const MAX_BLOCKS = 2^31 - 1 # benchmark only exercises 2048
# in-place tree reduction over a shared-memory buffer holding one value per
# thread: combines values with f until arr[1] holds the block-wide reduction
function reduce_block(arr::CuDeviceArray, f)
    sync_threads()
    len = blockDim().x
    while len != 1
        sync_threads()
        # fold the upper half of the active prefix onto the lower half
        skip = (len + 1) >> 1
        if threadIdx().x <= (len >> 1)
            arr[threadIdx().x] = f(arr[threadIdx().x],
                                   arr[threadIdx().x + skip])
        end
        len = skip
    end
    sync_threads()
end
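# For intuition, the same pairwise reduction written serially for the host
# (an illustrative sketch, not part of the original port): each pass halves
# the active prefix of the buffer until one combined value remains.
function tree_reduce!(arr::AbstractVector, f)
    len = length(arr)
    while len != 1
        skip = (len + 1) >> 1
        for i in 1:(len >> 1)
            arr[i] = f(arr[i], arr[i + skip])
        end
        len = skip
    end
    return arr[1]   # e.g. tree_reduce!([3.0, 1.0, 4.0, 1.5], max) == 4.0
end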
function softmax_kernel(arr::CuDeviceArray{T}, out::CuDeviceArray{T},
                        mask::Union{CuDeviceArray{T}, Nothing}) where {T}
    rows = size(out, 1)       # "rows" is the first dimension
    cols = length(out) ÷ rows # "cols" are dimensions 2:end
    # a mask of a different shape is broadcast against out; in that case
    # Broadcast.newindex translates each index into out to the matching
    # index into the smaller mask
    broadcast = (mask !== nothing && size(out) != size(mask))
    # serial map by cols ÷ blocks
    for colbase in 0:gridDim().x:(cols-1)
        col = colbase + blockIdx().x
        col > cols && return
        # serial reduction (max) by rows ÷ threads
        _max = @cuDynamicSharedMem(T, blockDim().x, blockDim().x * sizeof(T))
        _max[threadIdx().x] = -typemax(T)
        for rowbase in 0:blockDim().x:(rows-1)
            row = rowbase + threadIdx().x
            # skip (rather than return): this thread must still take part
            # in the block-level reductions and barriers below
            row > rows && continue
            idx = row + (col-1) * rows
            mask_val = T(1)
            if mask !== nothing
                mask_idx = idx
                if broadcast
                    mask_idx = Broadcast.newindex(
                        mask, CartesianIndices(out)[mask_idx])
                end
                mask_val = mask[mask_idx]
            end
            in_val = arr[idx]
            if mask_val != 0 && in_val > _max[threadIdx().x]
                _max[threadIdx().x] = in_val
            end
        end
        # block-parallel reduction (max) by threads
        reduce_block(_max, max)
        colmax = _max[1] # overall max of this column
        sync_threads()
        # serial reduction (sum) by rows ÷ threads
        # fused over an elementwise exp and subtraction
        _sum = @cuDynamicSharedMem(T, blockDim().x)
        _sum[threadIdx().x] = T(0)
        for rowbase in 0:blockDim().x:(rows-1)
            row = rowbase + threadIdx().x
            row > rows && continue
            idx = row + (col-1) * rows
            mask_val = T(1)
            if mask !== nothing
                mask_idx = idx
                if broadcast
                    # TODO lift out of the loop
                    mask_idx = Broadcast.newindex(
                        mask, CartesianIndices(out)[mask_idx])
                end
                mask_val = mask[mask_idx]
            end
            ex = T(0)
            if mask_val != 0
                ex = CUDAnative.exp_fast(arr[idx] - colmax)
            end
            #out[idx] = ex
            _sum[threadIdx().x] += ex
        end
        # block-parallel reduction (sum) by threads
        reduce_block(_sum, +)
        # _sum[1] is the overall sum of this column
        # broadcasted division of the exponentials by _sum[1];
        # masked-out entries get probability zero
        for rowbase in 0:blockDim().x:(rows-1)
            row = rowbase + threadIdx().x
            row > rows && continue
            idx = row + (col-1) * rows
            mask_val = T(1)
            if mask !== nothing
                mask_idx = idx
                if broadcast
                    mask_idx = Broadcast.newindex(
                        mask, CartesianIndices(out)[mask_idx])
                end
                mask_val = mask[mask_idx]
            end
            out[idx] = mask_val != 0 ?
                CUDAnative.exp_fast(arr[idx] - colmax) / _sum[1] : T(0)
        end
    end
end
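# Launch contract implied by the kernel above: dynamic shared memory holds
# two length-blockDim().x buffers of T, _sum at byte offset 0 and _max
# immediately after it, so a launch must pass shmem = 2 * threads * sizeof(T).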
function softmax!(arr::CuArray{T}, out::CuArray{T};
                  mask::Union{CuArray{T}, Nothing}=nothing) where {T}
    #device!(out.device)
    rows = size(out, 1)
    cols = length(out) ÷ rows
    blks = min(MAX_BLOCKS, cols)
    thrds = min(MAX_THREADS, rows)
    shared = sizeof(T) * thrds * 2
    @cuda blocks=blks threads=thrds shmem=shared softmax_kernel(
        arr, out, mask)
    return out
end
softmax(arr; mask=nothing) = softmax!(arr, similar(arr); mask=mask)
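# A host-side reference for checking results (an illustrative sketch, not
# part of the original gist): the kernel computes a numerically stable
# softmax over dimension 1, which on the CPU is
function cpu_softmax(x::AbstractArray)
    e = exp.(x .- maximum(x; dims=1))
    return e ./ sum(e; dims=1)
end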
include("softmax.jl") | |
a = CuArray(rand(Float32, (3, 4))) | |
b = CuArray(zeros(Float32, (3, 4))) | |
display(a) | |
softmax!(a, b) | |
display(b) | |
a = CuArray(rand(Float32, (16384, 64*32))); | |
b = CuArray(zeros(Float32, (16384, 64*32))); | |
@btime begin softmax!(a, b); synchronize() end | |
# V100: 826.665 μs (25 allocations: 720 bytes) | |
# vs cuDNN time of 820 μs (see https://github.com/JuliaGPU/GPUArrays.jl/issues/96) |
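# Optional correctness check against the cpu_softmax reference sketched at
# the end of softmax.jl (an added helper, not part of the original benchmark):
@assert Array(b) ≈ cpu_softmax(Array(a))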