Skip to content

Instantly share code, notes, and snippets.

@xukai92
Created October 26, 2020 10:14
Show Gist options
  • Save xukai92/5fdbc481e4c3360561d972647a690acd to your computer and use it in GitHub Desktop.
Save xukai92/5fdbc481e4c3360561d972647a690acd to your computer and use it in GitHub Desktop.
AHMC GPU improvements
using Random: AbstractRNG, GLOBAL_RNG
using AdvancedHMC: PhasePoint, phasepoint
"""
    AdvancedHMC.refresh(rng, z::PhasePoint{<:CuArray}, h::Hamiltonian)

Draw a fresh momentum on the GPU and return a new phase point built from `h`,
the unchanged position `z.θ`, and that momentum.

The momentum buffer is allocated with `similar(z.θ, size(h.metric)...)`, so it
takes its element type from `z.θ` and its shape from `size(h.metric)`. It is then
filled in place by `CUDA.CURAND.randn!`.

# Notes
- `rng` is part of the signature only; this method never reads it. The draw goes
  through `CUDA.CURAND.randn!` instead.
- The draw is passed to `phasepoint` as-is: no scaling by `h.metric` happens in
  this method.
"""
function AdvancedHMC.refresh(
    rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
    z::PhasePoint{T},
    h::Hamiltonian
) where {T<:CuArray}
    # Match the element type (and array family) of `z.θ` instead of hard-coding
    # `CuArray{Float32, 2}`; for a Float32 matrix this allocates the same buffer.
    r = similar(z.θ, size(h.metric)...)
    CUDA.CURAND.randn!(r)
    return phasepoint(h, z.θ, r)
end
"""
    AdvancedHMC.mh_accept_ratio(rng, Horiginal::CuArray, Hproposal::CuArray)

Compute the Metropolis-Hastings acceptance probabilities
`α = min.(1, exp.(Horiginal .- Hproposal))` and draw one accept/reject decision
per entry.

Returns `(accept, α)`, where `accept` is the element-wise result of comparing a
uniform draw against `α`.
"""
function AdvancedHMC.mh_accept_ratio(
    rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}}, Horiginal::TA, Hproposal::TA,
) where {T<:AbstractFloat, TA<:CuArray{<:T}}
    unit = one(T)
    energy_drop = Horiginal .- Hproposal
    α = min.(unit, exp.(energy_drop))
    # NOTE: Sharing one RNG across multiple chains for the accept / reject step
    #       might couple the chains. This needs to be revisited more rigorously
    #       in the future. See the discussion at
    # https://github.com/TuringLang/AdvancedHMC.jl/pull/166#pullrequestreview-367216534
    # The uniforms are drawn with `rng` and then converted to `TA` so that the
    # comparison against `α` is between arrays of the same type.
    n_draws = length(Horiginal)
    u = TA(rand(rng, T, n_draws))
    accept = u .< α
    return accept, α
end
"""
    AdvancedHMC.accept_phasepoint!(z, z′, is_accept)

Revert the rejected proposals stored in `z′` and return `z′`.

For every index where `is_accept` is `false`, the matching column of `θ`, `r`,
`ℓπ.gradient` and `ℓκ.gradient`, and the matching entry of `ℓπ.value` and
`ℓκ.value`, are copied from `z` into `z′`. `z′` is mutated in place; `z` is only
read.
"""
function AdvancedHMC.accept_phasepoint!(z::T, z′::T, is_accept) where {T<:PhasePoint{<:AbstractMatrix}}
    rejected = .!is_accept
    # Every proposal was accepted: `z′` is already the result.
    any(rejected) || return z′
    # Turn the logical mask into a host-side vector of integer indices so the
    # indexing below works with CUDA.jl.
    # NOTE: for x::CuArray, x[:,Vector{Bool}]  is NOT supported
    #                       x[:,CuVector{Int}] is NOT supported
    #                       x[:,Vector{Int}]   is supported
    cols = Array(findall(rejected))
    # Position and momentum
    z′.θ[:, cols] = z.θ[:, cols]
    z′.r[:, cols] = z.r[:, cols]
    # Cached value / gradient pair held in `ℓπ`
    z′.ℓπ.value[cols] = z.ℓπ.value[cols]
    z′.ℓπ.gradient[:, cols] = z.ℓπ.gradient[:, cols]
    # Cached value / gradient pair held in `ℓκ`
    z′.ℓκ.value[cols] = z.ℓκ.value[cols]
    z′.ℓκ.gradient[:, cols] = z.ℓκ.gradient[:, cols]
    # NOTE: `z′` is updated in place for memory efficiency. Copying `z′` first
    #       would avoid the mutation, but it costs an extra allocation and the
    #       immutability of `z′` does not matter in this local scope.
    return z′
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment