Skip to content

Instantly share code, notes, and snippets.

@jekbradbury
Last active June 11, 2018 03:42
Show Gist options
  • Save jekbradbury/becd9fff1c469dc73cb57f73d1f2a439 to your computer and use it in GitHub Desktop.
CUDAnative softmax, ported from Marian-NMT
# ported from Marian-NMT
# https://github.com/marian-nmt/marian-dev/blob/8fbfa656/src/tensors/gpu/tensor_operators.cu#L206-L320
# licensed under MIT
using CUDAdrv
using CUDAnative
using BenchmarkTools
# Upper bound on threads per block used by softmax!; 256 measured best here,
# although the hardware allows up to 1024.
const MAX_THREADS = 1 << 8
# Upper bound on blocks per launch (2^31 - 1, presumably CUDA's maximum
# grid x-dimension); the benchmark below only ever launches 2048 blocks.
const MAX_BLOCKS = Int(typemax(Int32))
"""
    reduce_block(arr, f)

Cooperatively fold the first `blockDim().x` entries of `arr` with the binary
operator `f`, leaving the final result in `arr[1]`. The remaining entries are
overwritten with partial results. Every thread of the block must call this,
because it hits `sync_threads()` barriers at each step.
"""
function reduce_block(arr::CuDeviceArray, f)
    tid = threadIdx().x
    sync_threads()
    # number of leading entries that still hold partial results
    active = blockDim().x
    while active > 1
        sync_threads()
        # ceil(active / 2): size of the prefix that survives this step
        half = (active + 1) >> 1
        # threads 1..floor(active/2) fold the upper part onto the lower part;
        # for an odd count the middle entry is carried over unchanged
        if tid <= active >> 1
            arr[tid] = f(arr[tid], arr[tid + half])
        end
        active = half
    end
    sync_threads()
    return nothing
end
# Weight of the element at linear index `idx` of `out` under `mask`.
# Without a mask every element is kept (weight one). With a mask, a zero weight
# excludes the element; when the mask's size differs from `out` (`bcast`), the
# index is mapped onto the mask with broadcasting semantics.
@inline _mask_value(::Nothing, out, idx, bcast) = one(eltype(out))
@inline function _mask_value(mask::CuDeviceArray, out, idx, bcast)
    if bcast
        return mask[Broadcast.newindex(mask, CartesianIndices(out)[idx])]
    end
    return mask[idx]
end

"""
    softmax_kernel(arr, out, mask)

GPU kernel computing the softmax of `arr` along its first dimension ("rows")
into `out`. Dimensions `2:end` are flattened into independent columns. Block
`b` handles columns `b, b + gridDim().x, ...`, and the threads of a block
stride over the rows of the current column. Entries whose mask value is zero
are excluded from the column max and sum and are written as zero.

Needs `2 * blockDim().x * sizeof(T)` bytes of dynamic shared memory: the first
half holds per-thread partial sums, the second half per-thread partial maxima.
"""
function softmax_kernel(arr::CuDeviceArray{T}, out::CuDeviceArray{T},
                        mask::Union{CuDeviceArray{T}, Nothing}) where {T}
    rows = size(out, 1)          # softmax axis: the first dimension
    cols = length(out) ÷ rows    # dimensions 2:end, flattened
    tid = threadIdx().x
    nthreads = blockDim().x
    # named `bcast` so it does not shadow Base.broadcast
    bcast = mask !== nothing && size(out) != size(mask)
    # the two scratch buffers occupy disjoint halves of dynamic shared memory
    _sum = @cuDynamicSharedMem(T, nthreads)
    _max = @cuDynamicSharedMem(T, nthreads, nthreads * sizeof(T))

    for colbase in 0:gridDim().x:(cols - 1)
        col = colbase + blockIdx().x
        # `col` is the same for every thread of the block, so the block exits as a unit
        col > cols && return nothing
        colstart = (col - 1) * rows

        # serial max over this thread's rows, then a block-wide max
        _max[tid] = -typemax(T)
        for rowbase in 0:nthreads:(rows - 1)
            row = rowbase + tid
            # `continue`, not `return`: every thread must still reach the
            # sync_threads() barriers inside reduce_block
            row > rows && continue
            idx = row + colstart
            in_val = arr[idx]
            if _mask_value(mask, out, idx, bcast) != 0 && in_val > _max[tid]
                _max[tid] = in_val
            end
        end
        reduce_block(_max, max)
        colmax = _max[1]    # max of the unmasked entries of this column
        # every thread must read _max[1] before it is reused
        sync_threads()

        # serial sum of exp(x - colmax) over this thread's rows, then block-wide sum
        _sum[tid] = zero(T)
        for rowbase in 0:nthreads:(rows - 1)
            row = rowbase + tid
            row > rows && continue
            idx = row + colstart
            if _mask_value(mask, out, idx, bcast) != 0
                _sum[tid] += CUDAnative.exp_fast(arr[idx] - colmax)
            end
        end
        reduce_block(_sum, +)
        # _sum[1] now holds the column total; reduce_block ended with a barrier,
        # and _sum is not reset again until after the next column's barriers
        colsum = _sum[1]

        # normalize; masked entries are written as zero, matching the upstream kernel
        for rowbase in 0:nthreads:(rows - 1)
            row = rowbase + tid
            row > rows && continue
            idx = row + colstart
            if _mask_value(mask, out, idx, bcast) != 0
                out[idx] = CUDAnative.exp_fast(arr[idx] - colmax) / colsum
            else
                out[idx] = zero(T)
            end
        end
    end
    return nothing
end
"""
    softmax!(arr, out; mask=nothing)

Compute the softmax of `arr` along its first dimension and store it in `out`.
Dimensions `2:end` are treated as independent columns. If `mask` is given,
entries where it is zero are excluded from each column's max and sum. A `mask`
whose size differs from `out` is indexed with broadcasting semantics.
Returns `out`.

Throws `DimensionMismatch` if `arr` and `out` differ in size.
"""
function softmax!(arr::CuArray{T}, out::CuArray{T};
                  mask::Union{CuArray{T}, Nothing}=nothing) where {T}
    # the kernel indexes `arr` with `out`'s linear indices
    size(arr) == size(out) || throw(DimensionMismatch(
        "arr has size $(size(arr)) but out has size $(size(out))"))
    # nothing to compute; this also avoids `÷ 0` below and a zero-sized launch
    isempty(out) && return out
    # NOTE(review): launches on the current device; the arrays are assumed to live there
    rows = size(out, 1)
    cols = length(out) ÷ rows
    blks = min(MAX_BLOCKS, cols)     # one block per column; blocks loop if there are more columns
    thrds = min(MAX_THREADS, rows)   # threads of a block stride over the rows
    shared = 2 * thrds * sizeof(T)   # two per-thread scratch buffers (sum and max)
    @cuda blocks=blks threads=thrds shmem=shared softmax_kernel(arr, out, mask)
    return out
end
"""
    softmax(arr; mask=nothing)

Allocating variant of `softmax!`. Returns a new array `similar(arr)` holding the
softmax of `arr` along its first dimension, optionally restricted by `mask`.
"""
function softmax(arr; mask=nothing)
    out = similar(arr)
    softmax!(arr, out; mask=mask)
    # return the result array explicitly rather than whatever softmax! returns
    return out
end
# Smoke test and benchmark for softmax! (loads softmax.jl from this file's directory).
include("softmax.jl")

# small example: print an input and its softmax along the first dimension
a = CuArray(rand(Float32, (3, 4)))
b = CuArray(zeros(Float32, (3, 4)))
display(a)
softmax!(a, b)
display(b)

# benchmark: 16384 rows (the softmax axis) x 2048 columns
a = CuArray(rand(Float32, (16384, 64*32)));
b = CuArray(zeros(Float32, (16384, 64*32)));
# `$` interpolation keeps the lookup of the non-const globals `a`/`b` out of the
# timing; synchronize() makes the timing include the kernel's completion
@btime begin softmax!($a, $b); synchronize() end
# V100: 826.665 μs (25 allocations: 720 bytes), measured without `$` interpolation
# vs cuDNN time of 820 μs (see https://github.com/JuliaGPU/GPUArrays.jl/issues/96)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment