Last active
October 25, 2024 16:47
-
-
Save rygorous/ea042174cc289c3153876d2cace970d2 to your computer and use it in GitHub Desktop.
BC4 interpolator results to float conversion attempt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def hw_14bit_to_floatu_v2(x, trace=False):
    """Return the IEEE-754 binary32 bit pattern of x/16320 (16320 = 255*64).

    The result is round-to-nearest, except in a few empirically determined
    cases where the round-up is undone (see "Magic futzing" below). Those
    adjustments reproduce the observed behavior of the AMD BC4_UNORM decoder
    hardware, so the result is always either the correctly rounded float or
    the float one ulp below it.

    Args:
        x: integer numerator. Values <= 0 return 0 (0.0); values >= 16320
            return 0x3f800000 (1.0).
        trace: if true, print the intermediate values of the computation.

    Returns:
        The float bit pattern as a non-negative int.
    """
    if x <= 0:
        return 0
    if x >= 255 * 64:
        return 0x3f800000

    # 16320 = 255*64
    # Divide by 64 is a freebie, so we're left having to divide by 255:
    #   1/255 = 1/256 * 256/255 = 1/256 * (1 / (255/256))
    #         = 1/256 * (1 / (1 - 1/256))
    # which by the usual sum formula applied in reverse is a geometric series:
    #         = 1/256 + 1/(256^2) + 1/(256^3) + ...
    #
    # After the leading few bits, this means we have a period of (at most) 8
    # bits, and we can compute the repeated bits directly (this will be "mid").
    # We can then build the mantissa and round with full knowledge of the
    # infinite series.

    # Top and middle bits of the infinite series expansion.
    if x >= 64:  # < 8 leading zeros
        # Biased exponent for the [0.5, 1) interval, the largest we can end up in.
        base_exp = 126
        top = x >> 6  # 1 <= top < 255 since 64 <= x < 255*64
        # The "middle" bits are a repeating 8-bit pattern: the top bits of the
        # next-lower term in the series expansion, plus the remaining 6 low
        # bits from the higher-order term.
        mid = top + ((x & 0x3f) << 2)
        if mid >= 256:  # end-around carry for the middle term
            # -256 removes the carry going out; +1 adds the carry coming in at
            # the bottom (the next-lower term carries into us the same way).
            mid -= 255
            top += 1  # carry into top bits (cannot carry out since top < 255)
    else:  # >= 8 leading zeros
        base_exp = 126 - 8
        # No carries possible in this case.
        top = x << 2  # top > 0 since x > 0
        mid = top

    # Leading zero count of top (as an 8-bit value) to figure out
    # normalization; top != 0 is guaranteed above.
    lzcnt = 0
    while (top << lzcnt) < (1 << 7):
        lzcnt += 1

    # Rotate the 8 middle bits left by lzcnt to align them with the
    # normalized top.
    mid_rot = ((mid << lzcnt) | (mid >> (8 - lzcnt))) & 0xff

    # Merged normalized 24-bit mantissa with its top bit set. Bits 23..16 hold
    # the normalized top, with its low lzcnt bits filled from mid. Bits 15..8
    # and 7..0 hold two copies of mid_rot.
    mid_in_top_mask = (1 << lzcnt) - 1  # mask for bits below the post-normalize "top"
    nmant_top = (top << lzcnt) | (mid_rot & mid_in_top_mask)
    nmant = (nmant_top << 16) | (mid_rot << 8) | mid_rot
    assert (1 << 23) <= nmant < (1 << 24)

    # Preliminary (truncated) result; the implicit leading 1 bit is dropped.
    result = ((base_exp - lzcnt) << 23) | (nmant & 0x7fffff)

    # Round to nearest.
    #
    # We don't need to worry about exact half-way cases. The rounding bit is
    # the MSB of the next copy of mid_rot. If it is 1, then (because it
    # repeats) there are infinitely many 1 bits further down the expansion.
    # So we are never exactly halfway and need no tie-breaking.
    #
    # Rounding can add 1. This can bump an nmant of (1<<24)-1 to (1<<24),
    # which means the exponent must be incremented and the mantissa set to 0.
    # Adding onto the binary representation of the float does exactly that.
    #
    # (This could be simplified slightly by using the known structure of the
    # mantissa. Either mid_rot = mid = 0xff, and a rounding add will carry all
    # the way into the top 8 bits of the mantissa, leaving the lower bits 0.
    # Or a rounding increment will not propagate past the low 8 bits.)
    result += mid_rot >> 7

    # Magic futzing.
    # Without these adjustments, which undo the rounding increment above in
    # certain cases, this calculation is exact. The conditions are empirical:
    # they select where to round incorrectly so the result matches the HW.
    # The rule behind them has not been identified. Both conditions imply
    # mid_rot >= 0x80, so they only ever undo an increment that was applied.
    if mid_rot == 0x80 and top < 0x40 and (top & 3) == 0 and (top & (top - 1)) != 0:
        result -= 1
    if mid_rot >= 0x80 and base_exp == 126 - 8 and lzcnt == 2:
        result -= 1

    if trace:
        # Pre-carry values, recomputed for comparison with top/mid.
        orig_top = x >> 6
        orig_mid = orig_top + ((x & 0x3f) << 2)
        print('top=0x{:x} mid=0x{:02x} lzcnt={:d} mid_rot=0x{:02x} nmant=0x{:x} orig_top=0x{:02x} orig_mid=0x{:02x}'.format(top, mid, lzcnt, mid_rot, nmant, orig_top, orig_mid))
    return result
def fullprec_14bit_to_floats(x, trace=False):
    """Return the IEEE-754 binary32 bit pattern of x/8128 (8128 = 127*64).

    The magnitude |x|/8128 is rounded to nearest with no empirical
    adjustments, so the result is the correctly rounded float. Negative
    inputs return the same magnitude with the sign bit (0x80000000) set.
    NOTE(review): the original header described this as matching AMD
    BC4_SNORM hardware with slightly-off rounding. Nothing in this code
    implements such an adjustment; confirm against hardware whether exact
    round-to-nearest is what the HW does.

    Args:
        x: integer numerator. 0 returns 0 (+0.0); |x| >= 8128 returns
            +/-1.0 (0x3f800000 with the sign bit of x).
        trace: if true, print the intermediate values of the computation.

    Returns:
        The float bit pattern as a non-negative int (sign bit in bit 31).
    """
    # Sign handling: negative values remember the sign bit but continue
    # working on the absolute value.
    sign_bit = 0
    if x < 0:
        x = -x
        sign_bit = 0x80000000
    if x == 0:
        return 0
    if x >= 127 * 64:
        return sign_bit | 0x3f800000

    # The series expansion here works the same way as in the unsigned case,
    # except that we divide by 127*64. Using 127 instead of 255 leads to the
    # series
    #   (x/64) * (1/128 + 1/(128^2) + 1/(128^3) + ...)
    # whose bits repeat with a period of (at most) 7.

    # Top and middle bits of the infinite series expansion.
    if x >= 64:  # < 7 leading zeros
        base_exp = 126
        top = x >> 6  # 1 <= top < 127 by bounds on x
        # The "middle" bits are a repeating 7-bit pattern: the top bits of the
        # next-lower term in the series expansion, plus the remaining 6 low
        # bits from the higher-order term.
        mid = top + ((x & 0x3f) << 1)
        if mid >= 128:
            # -128 to negate wraparound, +1 for the carry coming in from the
            # next-lower term.
            mid -= 127
            top += 1  # carry into top too (cannot carry out since top < 127)
    else:  # >= 7 leading zeros
        base_exp = 126 - 7
        # No carries possible in this case.
        top = x << 1  # top > 0 since x > 0
        mid = top

    # Leading zero count of top (as a 7-bit value) to figure out
    # normalization; top != 0 is guaranteed above.
    lzcnt = 0
    while (top << lzcnt) < (1 << 6):
        lzcnt += 1

    # Rotate the 7 middle bits left by lzcnt to align them with the
    # normalized top.
    mid_rot = ((mid << lzcnt) | (mid >> (7 - lzcnt))) & 0x7f

    # Merged normalized 24-bit mantissa with an explicit leading 1 bit.
    # Bits 23..17 hold the normalized top, with its low lzcnt bits filled
    # from mid. Bits 16..10 and 9..3 hold two copies of mid_rot. Bits 2..0
    # hold the top 3 bits of a third copy.
    mid_in_top_mask = (1 << lzcnt) - 1
    nmant_top = (top << lzcnt) | (mid_rot & mid_in_top_mask)
    nmant = (nmant_top << 17) | (mid_rot << 10) | (mid_rot << 3) | (mid_rot >> 4)
    assert (1 << 23) <= nmant < (1 << 24)

    # Preliminary (truncated) result; the implicit leading 1 bit is dropped.
    result = sign_bit | ((base_exp - lzcnt) << 23) | (nmant & 0x7fffff)

    # Round to nearest.
    #
    # The rounding bit is the next bit of the repeating pattern after the
    # three copied into bits 2..0, i.e. bit 3 of mid_rot. We don't need to
    # worry about exact half-way cases. If the rounding bit is set, later
    # copies of the pattern contain it again, so the discarded tail is
    # strictly above 1/2 ulp and we should always round up.
    #
    # Rounding can add 1. If all mantissa bits were set, adding 1 to the bit
    # pattern of the float clears the mantissa bits and bumps the exponent up
    # by 1, which is the correct thing to do. (This is the traditional way of
    # handling rounding here.) The sign bit is unaffected.
    result += (mid_rot >> 3) & 1

    if trace:
        print('top=0x{:x} mid=0x{:02x} lzcnt={:d} mid_rot=0x{:02x} nmant=0x{:x}'.format(top, mid, lzcnt, mid_rot, nmant))
    return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment