Minimal Julia implementation of NF4 floating point for QLoRA
using Statistics
using BFloat16s
using StaticArrays
import Base: getindex, setindex!, length, iterate

###########################################
# Implementation of the NormedFloat4 type
# and its container type, QLoRAArray
#
# Ref:
# https://arxiv.org/abs/2305.14314
# https://github.com/artidoro/qlora
###########################################
# A data type for one-dimensional UInt4 arrays that index like a normal array but are stored in packed nibbles | |
struct PackedUInt4Array{N} | |
data::MVector{N, UInt8} | |
end | |
function PackedUInt4Array(N) #This N is the logical size | |
Nbytes = ceil(Int, N/2) | |
data = Vector{UInt8}(undef, Nbytes) | |
PackedUInt4Array{Nbytes}(data) | |
end | |
get_high_nibble(byte::UInt8) = byte >> 4
get_low_nibble(byte::UInt8) = byte & 0x0f
pack_high_nibble(byte::UInt8, v::UInt8) = (v << 4) | (byte & 0x0f)
pack_low_nibble(byte::UInt8, v::UInt8) = v | (byte & 0xf0)
pack_nibbles(hi::UInt8, lo::UInt8) = (hi << 4) | lo
# e.g. pack_nibbles(0x0d, 0x08) == 0xd8
function getindex(A::PackedUInt4Array, idx::Integer)
    offset = (idx-1) >> 1
    nibble_idx = (idx-1) & 0x1
    byte = A.data[offset+1]
    if nibble_idx == 0
        nibble = get_high_nibble(byte)
    else # nibble_idx == 1
        nibble = get_low_nibble(byte)
    end
    return nibble
end
function setindex!(A::PackedUInt4Array, X::UInt8, idx::Integer)
    # Silently truncate input to 4 bits
    z = X & 0xf
    offset = (idx-1) >> 1
    nibble_idx = (idx-1) & 0x1
    byte = A.data[offset+1]
    if nibble_idx == 0
        new_byte = pack_high_nibble(byte, z)
    else # nibble_idx == 1
        new_byte = pack_low_nibble(byte, z)
    end
    A.data[offset+1] = new_byte
end
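# A minimal sanity check of the packed-nibble round trip: every 4-bit
# value written through setindex! should read back unchanged.
let A = PackedUInt4Array(5)
    vals = UInt8[0x3, 0xf, 0x0, 0x8, 0xd]
    for (i, v) in enumerate(vals)
        A[i] = v
    end
    @assert [A[i] for i in 1:5] == vals
end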
struct QLoRAArray{T, N}
    μ::T
    σ::T
    scale::T
    data::PackedUInt4Array
end
# The 16 NF4 quantiles (Appendix E of arXiv:2305.14314)
const NF4Quantiles = [-1.0, -0.6961928009986877, -0.5250730514526367,
    -0.39491748809814453, -0.28444138169288635, -0.18477343022823334,
    -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
    0.24611230194568634, 0.33791524171829224, 0.44070982933044434,
    0.5626170039176941, 0.7229568362236023, 1.0]
function quantizeNF4(x)
    # Round x to the nearest NF4 quantile and return its 4-bit code;
    # assumes x has already been scaled into [-1, 1] (Appendix E)
    q = searchsortedfirst(NF4Quantiles, x)
    if q > 1 && (x - NF4Quantiles[q-1] < NF4Quantiles[q] - x) # RoundNearest
        q -= 1
    end
    UInt8(max(q-1, 0))
end
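# Sanity checks: 0.0 is itself the 8th quantile, so it quantizes to code
# 7, and the endpoints ±1 map to the extreme codes 0 and 15.
@assert quantizeNF4(0.0) == 0x07
@assert quantizeNF4(-1.0) == 0x00
@assert quantizeNF4(1.0) == 0x0f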
function QLoRAArray(x::AbstractArray{T}) where T
    # Standardize x - could probably use StatsBase.ZScoreTransform
    μ, σ = mean(x), std(x)
    z = (x .- μ) ./ σ
    # Maximum absolute scaling into [-1, 1]
    zmin, zmax = extrema(z)
    scale = 1/max(-zmin, zmax)
    N = length(x)
    data = PackedUInt4Array(N)
    for i in eachindex(z)
        data[i] = quantizeNF4(z[i]*scale)
    end
    QLoRAArray{T, N}(μ, σ, scale, data)
end
function getindex(A::QLoRAArray{T}, idx::Integer) where T
    v = A.data[idx]                  # 4-bit quantile code
    z = NF4Quantiles[v+1]/A.scale    # dequantize in standardized units
    x = z*A.σ + A.μ                  # undo the standardization
    return x
end
length(::QLoRAArray{T, N}) where {T, N} = N
iterate(A::QLoRAArray{T, N}) where {T, N} = N==0 ? nothing : (A[1], 1)
iterate(A::QLoRAArray{T, N}, i) where {T, N} = i==N ? nothing : (A[i+1], i+1)
###########################################
# Tiny demo
###########################################
v = randn(BFloat16, 80)
z = QLoRAArray(v)
println("Quantization max abs error = ", maximum(abs.(v .- z)))
using Plots
plot([v, BFloat16.(z)])
###########################################
# Implementation of the Float8 floating point types
#
# NVIDIA H100s now support two 8-bit floating point formats
# with different bit lengths for exponent and mantissa/significand
#
# Why, AI people?? WHY??
#
# Ref:
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html
# https://arxiv.org/abs/2209.05433
#
###########################################
primitive type Float8_E4M3 <: AbstractFloat 8 end
primitive type Float8_E5M2 <: AbstractFloat 8 end
const Float8 = Float8_E4M3 # Looks like this is the more common one for the forward pass
const Float8s = Union{Float8_E4M3, Float8_E5M2}
Base.sign_mask(::Type{Float8_E4M3}) = 0x80
Base.sign_mask(::Type{Float8_E5M2}) = 0x80
Base.exponent_one(::Type{Float8_E4M3}) = 0x38 # bias 7 << 3 significand bits
Base.exponent_one(::Type{Float8_E5M2}) = 0x3c # bias 15 << 2 significand bits
Base.exponent_half(::Type{Float8_E4M3}) = 0x30
Base.exponent_half(::Type{Float8_E5M2}) = 0x38
Base.exponent_bias(::Type{Float8_E4M3}) = 7
Base.exponent_bias(::Type{Float8_E5M2}) = 15
Base.exponent_bits(::Type{Float8_E4M3}) = 4
Base.exponent_bits(::Type{Float8_E5M2}) = 5
Base.exponent_mask(::Type{Float8_E4M3}) = 0b0111_1000
Base.exponent_mask(::Type{Float8_E5M2}) = 0b0111_1100
Base.significand_bits(::Type{Float8_E4M3}) = 3
Base.significand_bits(::Type{Float8_E5M2}) = 2
Base.significand_mask(::Type{Float8_E4M3}) = 0b0000_0111
Base.significand_mask(::Type{Float8_E5M2}) = 0b0000_0011
Base.signbit(x::Float8s) = (reinterpret(UInt8, x) & 0x80) !== 0x00
# The infinities: E4M3 has none; E5M2 uses exponent all ones, mantissa zero
Base.isinf(x::Float8_E4M3) = false
Base.isinf(x::Float8_E5M2) = (reinterpret(UInt8, x) & 0b0111_1111) == 0b0111_1100
# The NaNs: E4M3 reserves only S.1111.111; E5M2 has exponent all ones, mantissa nonzero
const NaN8_mask = 0b0111_1111
Base.isnan(x::Float8_E4M3) = (reinterpret(UInt8, x) & NaN8_mask) == NaN8_mask
Base.isnan(x::Float8_E5M2) = (reinterpret(UInt8, x) & NaN8_mask) > 0b0111_1100
const NaN8_E4M3 = reinterpret(Float8_E4M3, NaN8_mask)
const NaN8_E5M2 = reinterpret(Float8_E5M2, NaN8_mask)
# Inf8_E4M3 does not exist
const Inf8_E5M2 = reinterpret(Float8_E5M2, 0b0111_1100)
Base.floatmax(::Type{Float8_E4M3}) = reinterpret(Float8_E4M3, 0b0111_1110) # 448
Base.floatmax(::Type{Float8_E5M2}) = reinterpret(Float8_E5M2, 0b0111_1011) # 57344
Base.floatmin(::Type{Float8_E4M3}) = reinterpret(Float8_E4M3, 0b0000_1000) # 2^-6
Base.floatmin(::Type{Float8_E5M2}) = reinterpret(Float8_E5M2, 0b0000_0100) # 2^-14
Base.typemin(::Type{Float8_E4M3}) = -Base.floatmax(Float8_E4M3)
Base.typemin(::Type{Float8_E5M2}) = -Inf8_E5M2
Base.typemax(::Type{Float8_E4M3}) = Base.floatmax(Float8_E4M3)
Base.typemax(::Type{Float8_E5M2}) = Inf8_E5M2
Base.promote_rule(::Type{S}, ::Type{T}) where {S<:AbstractFloat,T<:Float8s} = S
Base.promote_rule(::Type{Float8_E4M3}, ::Type{Float8_E5M2}) = Float8_E5M2
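# Consequently mixed arithmetic promotes away from 8 bits, and E4M3 with
# E5M2 promotes to E5M2 (a choice made here, not mandated by any standard):
@assert promote_type(Float8_E4M3, Float32) == Float32
@assert promote_type(Float8_E4M3, Float8_E5M2) == Float8_E5M2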
########################################################
# Interconversion with other floats (primarily B/Float16)
########################################################
# E5M2 has the same structure as Float16 but with the last 8 bits of the
# mantissa lopped off, so converting is just truncating the low byte
Base.convert(::Type{Float8_E5M2}, x::AbstractFloat) = (
    reinterpret(Float8_E5M2, UInt8(reinterpret(UInt16, Float16(x)) >> 8)))
Float8_E5M2(x) = Base.convert(Float8_E5M2, x)
Base.convert(::Type{Float8_E4M3}, x::AbstractFloat) = convert(Float8_E4M3, Float16(x))
Float8_E4M3(x) = Base.convert(Float8_E4M3, x)
Base.convert(::Type{Float8_E4M3}, x) = Base.convert(Float8_E4M3, Float16(x))
function Base.convert(::Type{Float8_E4M3}, x::Float16)
    isnan(x) && return NaN8_E4M3
    isinf(x) && error("Infinities not representable")
    # NB: truncates the mantissa (no rounding) and does not handle
    # exponents outside the E4M3 range or subnormals
    z = reinterpret(UInt16, x)
    sign_bits = UInt8((Base.sign_mask(Float16) & z) >> 8)
    significand_bitshift = Base.significand_bits(Float16) - Base.significand_bits(Float8_E4M3)
    significand_bits = UInt8((Base.significand_mask(Float16) & z) >> significand_bitshift)
    exponent = ((Base.exponent_mask(Float16) & z) >> Base.significand_bits(Float16)) - Base.exponent_bias(Float16)
    exponent_bits = UInt8(exponent + Base.exponent_bias(Float8_E4M3)) << Base.significand_bits(Float8_E4M3)
    final_bits = sign_bits | exponent_bits | significand_bits
    return reinterpret(Float8_E4M3, final_bits)
end
Base.Float16(x::Float8s) = Base.convert(Float16, x)
Base.Float32(x::Float8s) = Float32(Float16(x))
Base.Float64(x::Float8s) = Float64(Float16(x))
function Base.convert(::Type{Float16}, x::Float8s)
    T = typeof(x)
    isnan(x) && return NaN16
    z = reinterpret(UInt8, x)
    sign_bits = UInt16(Base.sign_mask(T) & z) << 8
    significand_bitshift = Base.significand_bits(Float16) - Base.significand_bits(T)
    significand_bits = UInt16(Base.significand_mask(T) & z) << significand_bitshift
    exponent = ((Base.exponent_mask(T) & z) >> Base.significand_bits(T)) - Base.exponent_bias(T)
    exponent_bits = UInt16(exponent + Base.exponent_bias(Float16)) << Base.significand_bits(Float16)
    final_bits = sign_bits | exponent_bits | significand_bits
    return reinterpret(Float16, final_bits)
end
z = convert(Float8_E4M3, 1.25)
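# 1.25 needs only two significand bits (1.01₂ × 2⁰), so it is exactly
# representable in E4M3 and the round trip through Float16 is exact:
@assert Float16(z) == Float16(1.25)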
###########################################
# A very incomplete implementation of arithmetic mediated by Float16s
# The docs seem to indicate that the Float8s are storage-only and
# native arithmetic is in some 16-bit format, possibly BFloat16
###########################################
import Base: +, -, *, /
# Unary minus just flips the sign bit (needed by typemin above)
-(x::T) where T<:Float8s = reinterpret(T, reinterpret(UInt8, x) ⊻ 0x80)
+(x::T, y::T) where T<:Float8s = T(Float16(x) + Float16(y))
-(x::T, y::T) where T<:Float8s = T(Float16(x) - Float16(y))
*(x::T, y::T) where T<:Float8s = T(Float16(x) * Float16(y))
/(x::T, y::T) where T<:Float8s = T(Float16(x) / Float16(y))
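# A small illustration: each operation rounds back to 8 bits, so exactly
# representable results survive while extra mantissa bits are truncated.
let a = Float8_E4M3(1.25)
    @assert Float16(a + a) == Float16(2.5)  # 2.5 is exact in E4M3
    @assert Float16(a * a) == Float16(1.5)  # 1.5625 truncates to 1.5 (only 3 mantissa bits)
end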
###########################################
# An incomplete implementation of the doubly quantized arrays
# used by QLoRA
###########################################
struct ChunkedQuantizedArray{Nx, Ny, S, Nb, B, T} <: StaticArray{Tuple{Nx, Ny}, S, 2}
    c :: SVector{Nb, S} # per-block quantization constants
    data :: Matrix{T}   # quantized values, one block per row
end
# (1) of arXiv:2305.14314v1
function quantize(::Type{T}, X::AbstractMatrix{S},
                  B::Int=64, # Block size
                  rm::RoundingMode = RoundNearest) where {S,T}
    Nx, Ny = size(X)
    Nb = ceil(Int, length(X)/B) # Number of blocks
    c = zeros(MVector{Nb, S})   # NB: @MVector needs a literal size, so build the type directly
    data = Array{T}(undef, Nb, B)
    for ib in 1:Nb # Iterate over blocks
        Xb = view(X, ((ib-1)*B+1):min(ib*B, length(X)))
        c[ib] = typemax(T) / maximum(abs.(Xb))
        data[ib, 1:length(Xb)] = round.(T, c[ib]*Xb, rm)
    end
    ChunkedQuantizedArray{Nx, Ny, S, Nb, B, T}(c, data)
end
function getindex(X::ChunkedQuantizedArray{Nx, Ny, S, Nb, B, T}, idx::Int) where {Nx, Ny, S, Nb, B, T}
    # Linear index idx falls in block fld1(idx, B) at position mod1(idx, B)
    ib, pos = fld1(idx, B), mod1(idx, B)
    X.data[ib, pos] / X.c[ib]
end
function doublequantize(::Type{T2}, ::Type{T}, X::AbstractMatrix{S},
                        B::Int=64, B2::Int=256, # Block sizes
                        rm::RoundingMode = RoundNearest) where {S,T,T2}
    # First quantize X blockwise, then quantize the per-block constants c
    Xc = quantize(T, X, B, rm)
    Xc2 = quantize(T2, reshape(Xc.c, length(Xc.c), 1), B2, rm)
    (Xc2.c, Xc2.data, Xc.data)
end
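# Hypothetical usage (shapes chosen for illustration): block-quantize a
# 16×16 Float32 matrix to Int8, then check that dequantizing through
# getindex lands within one quantization step of the original value.
let X = randn(Float32, 16, 16)
    Xq = quantize(Int8, X, 64)
    @assert size(Xq.data) == (4, 64)        # 256 elements in 4 blocks of 64
    @assert abs(Xq[1] - X[1]) <= 1/Xq.c[1]  # |round(c*x)/c - x| ≤ 1/(2c) < 1/c
end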
###########################################
# A skeleton of the low-rank adapter
# "module"
###########################################
struct LowRankAdapter
    # Forward pass computes Y = X*W + s*X*L1*L2, i.e. the frozen weights W
    # plus a scaled low-rank update L1*L2
    W  # size (h, o)
    s  # scalar scaling factor
    L1 # size (h, r)
    L2 # size (r, o)
end
# This is how LowRankAdapter acts on input X in the forward pass
# TODO Figure out the canonical spelling in MLJ or Flux
function forward(A::LowRankAdapter, X)
    Y = X * A.W + A.s * X * A.L1 * A.L2
    return Y
end
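# Hypothetical shapes for illustration (h inputs, o outputs, rank r):
# the adapted forward pass returns one output row per input row.
let h = 4, r = 2, o = 3
    A = LowRankAdapter(randn(h, o), 0.5, randn(h, r), randn(r, o))
    X = randn(5, h)
    @assert size(forward(A, X)) == (5, o)
end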
Comment: on L#67, I think it should be q > 1, otherwise you'd be indexing with zero for q = 1.