Skip to content

Instantly share code, notes, and snippets.

@rygorous
Last active October 25, 2024 16:47
Show Gist options
  • Save rygorous/ea042174cc289c3153876d2cace970d2 to your computer and use it in GitHub Desktop.
Save rygorous/ea042174cc289c3153876d2cace970d2 to your computer and use it in GitHub Desktop.
BC4 interpolator results to float conversion attempt
# Computes the 32-bit IEEE float bit pattern for x/16320 (where x is given as an integer),
# with rounding slightly off from proper RN, matching observed behavior of AMD BC4_UNORM
# decoder HW
def hw_14bit_to_floatu_v2(x, trace=False):
    """Return the IEEE binary32 bit pattern (as an int) for x / 16320.

    The result is round-to-nearest except for two families of inputs where
    the final rounding increment is deliberately undone (see "Magic futzing"
    below).

    Args:
        x: Integer numerator. Values <= 0 clamp to 0 (bit pattern 0) and
            values >= 255 * 64 clamp to 1.0 (bit pattern 0x3f800000).
        trace: When true, print the intermediate values for in-range inputs
            (clamped inputs return before anything is printed).

    Returns:
        The 32-bit float bit pattern as a non-negative int.
    """
    if x <= 0:
        return 0
    if x >= 255 * 64:
        return 0x3f800000

    # 16320 = 255*64
    # Divide by 64 is a freebie, so we're left having to divide by 255
    #   1/255 = 1/256 * 256/255 = 1/256 * (1 / (255/256))
    #         = 1/256 * (1 / (1 - 1/256))
    # which by the usual sum formula applied in reverse is a geometric series:
    #         = 1/256 + 1/(256^2) + 1/(256^3) + ...
    #
    # After the leading few bits, this means we have a period of (at most) 8 bits,
    # and we can compute the repeated bits directly (this will be "mid"). We can
    # then later build the mantissa and round with full knowledge of the infinite
    # series.

    # Top and middle bits of infinite series expansion
    if x >= 64:  # <8 leading zeros
        # Biased exponent for the [0.5,1) interval, the largest we can end up in
        base_exp = 126
        top = x >> 6  # 1 <= top < 255 since 64 <= x < 255*64

        # The "middle" bits are a repeating 8-bit pattern: the top bits of the
        # next-lower term in the series expansion, plus the remaining 6 low bits
        # from the higher-order term
        mid = top + ((x & 0x3f) << 2)
        if mid >= 256:  # End-around carry for middle term
            # -256 to remove the carry going out, +1 to add the carry in at the
            # bottom (since the next-lower term carries into us)
            mid -= 255
            top += 1  # carry into top bits (will not carry out since top<255)
    else:  # >=8 leading zeros
        base_exp = 126 - 8
        # No carries possible in this case
        top = x << 2  # top > 0 since x > 0
        mid = top

    # Leading zero count of top within its 8-bit field, to figure out
    # normalization. 1 <= top <= 255 here, so bit_length() is in 1..8 and
    # lzcnt is in 0..7.
    lzcnt = 8 - top.bit_length()

    # Rotate 8 middle bits to align with lzcnt
    mid_rot = ((mid << lzcnt) | (mid >> (8 - lzcnt))) & 0xff

    # Merged normalized mantissa
    # this produces a normalized 24-bit mantissa with top bit set
    mid_in_top_mask = (1 << lzcnt) - 1  # mask for bits below the post-normalize "top"
    nmant_top = (top << lzcnt) | (mid_rot & mid_in_top_mask)
    nmant = (nmant_top << 16) | (mid_rot << 8) | mid_rot
    assert (1 << 23) <= nmant < (1 << 24)

    # Build preliminary result
    result = ((base_exp - lzcnt) << 23) | (nmant & 0x7fffff)

    # Round to nearest
    #
    # We don't need to worry about exact half-way cases; if the rounding
    # bit (which is the MSB of mid_rot) is 1, then (because it repeats) we
    # have an infinite number of 1 bits further down the expansion, so we're
    # never exactly halfway and don't need to worry about tie-breaker cases.
    #
    # Rounding can add 1; this can bump a nmant of (1<<24)-1 up to (1<<24),
    # which means we need to increment the exponent and set the mantissa to 0.
    # This is traditionally handled by adding onto the binary representation
    # of the float, which does the right thing.
    #
    # (This could once again be simplified slightly by using the known structure
    # of the mantissa; either mid_rot = mid = 0xff and a rounding add will carry
    # all the way into the top 8 bits of the mantissa, leaving the lower bits 0,
    # or a rounding increment will not propagate past the low 8 bits.)
    result += mid_rot >> 7

    # Magic futzing
    # Without these adjustments, which undo the above rounding term in certain
    # cases, this calculation is exact.
    #
    # In short, these are basically just empirical expressions that determine
    # where to round incorrectly to match the HW; the system behind them has not
    # been figured out.
    if mid_rot == 0x80 and top < 0x40 and (top & 3) == 0 and (top & (top - 1)) != 0:
        result -= 1
    if mid_rot >= 0x80 and base_exp == 126 - 8 and lzcnt == 2:
        result -= 1

    if trace:
        # orig_top/orig_mid are the values before any end-around carry was applied
        orig_top = x >> 6
        orig_mid = orig_top + ((x & 0x3f) << 2)
        print('top=0x{:x} mid=0x{:02x} lzcnt={:d} mid_rot=0x{:02x} nmant=0x{:x} orig_top=0x{:02x} orig_mid=0x{:02x}'.format(top, mid, lzcnt, mid_rot, nmant, orig_top, orig_mid))

    return result
# Computes the 32-bit IEEE float bit pattern for x/(127*64) (where x is given as an integer),
# with rounding slightly off from proper RN, matching observed behavior of AMD BC4_SNORM
# decoder HW
# NOTE(review): unlike the unsigned variant, this function applies no adjustments
# after its round-to-nearest step -- confirm whether the "slightly off" wording above
# is still accurate for this function.
def fullprec_14bit_to_floats(x, trace=False):
    """Return the IEEE binary32 bit pattern (as an int) for x / (127 * 64).

    Args:
        x: Signed integer numerator. The magnitude is clamped so that
            |x| >= 127 * 64 yields +/-1.0; x == 0 yields bit pattern 0.
        trace: When true, print the intermediate values for in-range inputs
            (zero and clamped inputs return before anything is printed).

    Returns:
        The 32-bit float bit pattern as a non-negative int, with bit 31 set
        for negative non-zero inputs.
    """
    # Sign handling: negative values remember the sign
    # bit but continue working on the absolute value
    sign_bit = 0
    if x < 0:
        x = -x
        sign_bit = 0x80000000

    if x == 0:
        return 0
    if x >= 127 * 64:
        return sign_bit | 0x3f800000

    # The series expansion here works the same way as the unsigned case, except
    # we divide by 127*64, and the 127 instead of 255 leads to the series
    #   x/128 + x/(128^2) + x/(128^3) + ...

    # Top and middle bits of infinite series expansion
    if x >= 64:  # <7 leading zeros
        base_exp = 126
        top = x >> 6  # 1 <= top < 127 by bounds on x

        # The "middle" bits are a repeating 7-bit pattern: the top bits of the
        # next-lower term in the series expansion, plus the remaining 6 low bits
        # from the higher-order term
        mid = top + ((x & 0x3f) << 1)
        if mid >= 128:
            # -128 to negate wraparound, +1 for carry coming in from next-lower term
            mid -= 127
            top += 1  # carry into top too
    else:  # >=7 leading zeros
        base_exp = 126 - 7
        # No carries possible in this case
        top = x << 1  # top > 0 since x > 0
        mid = top

    # Leading zero count of top within its 7-bit field, to figure out
    # normalization. 1 <= top <= 127 here, so bit_length() is in 1..7 and
    # lzcnt is in 0..6.
    lzcnt = 7 - top.bit_length()

    # Rotate middle 7 bits to align with lzcnt
    mid_rot = ((mid << lzcnt) | (mid >> (7 - lzcnt))) & 0x7f

    # Merged normalized mantissa; 24-bit mantissa with explicit 1 bit.
    # Layout: 7 top bits, two full copies of the 7-bit pattern, then the
    # upper 3 bits of a third copy.
    mid_in_top_mask = (1 << lzcnt) - 1
    nmant_top = (top << lzcnt) | (mid_rot & mid_in_top_mask)
    nmant = (nmant_top << 17) | (mid_rot << 10) | (mid_rot << 3) | (mid_rot >> 4)

    # Preliminary result
    result = sign_bit | ((base_exp - lzcnt) << 23) | (nmant & 0x7fffff)

    # Round to nearest
    #
    # The rounding bit is the first pattern bit that did not fit into the
    # mantissa, i.e. bit 3 of mid_rot.
    #
    # We don't need to worry about exact half-way cases; if the rounding bit is
    # set, we have an infinite string of copies of that bit in lower-order terms
    # of the infinite series, so we are strictly above 1/2 and should always
    # round up.
    #
    # Rounding can add 1, and if all mantissa bits were set, adding 1 to the bit
    # pattern for the float clears the mantissa bits and bumps the exponent up by
    # 1, which is the correct thing to do. (This is the traditional way of
    # handling rounding here.)
    result += (mid_rot >> 3) & 1

    if trace:
        print('top=0x{:x} mid=0x{:02x} lzcnt={:d} mid_rot=0x{:02x} nmant=0x{:x}'.format(top, mid, lzcnt, mid_rot, nmant))

    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment