// half->float variants.
// by Fabian "ryg" Giesen.
//
// I hereby place this code in the public domain.
//
// half_to_float_fast: table based.
// the tables could be done in a more compact fashion (in particular, tab2 can be stored
// in the low word of tab1!), but it's something of a dead end since it's not very
// SIMD-friendly. pretty much abandoned at this point.
//
// half_to_float_fast2: use FP adder hardware to deal with denormals.
// now this one has potential! (but needs some polish)
//
// half_to_float_fast3: same, but without bitfields. low number of constants involved.
// looking pretty good. need to check how it comes out in SSE2.
//
// half_to_float_fast4: again the same, but written to be easy to translate to SSE2
// intrinsics.
//
// half_to_float4_SSE2: initial port of half_to_float_fast4 to SSE2. should do a second
// pass + clean-up, but not today. :)
//
// half_to_float4b_SSE2: some tweaks. should generate fewer ops and needs fewer constants;
// about 11% faster on my Sandy Bridge i7 using VS2010. YMMV.
//
// half_to_float_fast5: slightly different approach, turns FP16 denormals into FP32 denormals.
// it's very slick and short but will be slower if denormals actually occur.
//
// half_to_float5_SSE2: SSE2-ified version of the "fast5" variant. as said, in the presence of
// denormals this will be noticeably slower than variants 4/4b; use the included benchmarking
// code to find out by how much :). it's kinda serial, which means that even though it uses
// far fewer instructions than variants 4/4b, it's not all that much faster even in the
// best case.
#include <stdio.h>
#include <emmintrin.h>
#include <intrin.h>

typedef unsigned int uint;

union FP32
{
    uint u;
    float f;
    struct
    {
        uint Mantissa : 23;
        uint Exponent : 8;
        uint Sign : 1;
    };
};

union FP16
{
    unsigned short u;
    struct
    {
        uint Mantissa : 10;
        uint Exponent : 5;
        uint Sign : 1;
    };
};
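// Side note (not in the original gist): reading a different union member than
// the one last written is the classic type-punning idiom. It's sanctioned in
// C99/C11 and supported as an extension by MSVC and GCC in C++, but it is not
// strictly portable C++; memcpy is the pedantically safe alternative.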
static FP32 half_to_float_full(FP16 h)
{
    FP32 o = { 0 };

    // From ISPC ref code
    if (h.Exponent == 0 && h.Mantissa == 0) // (Signed) zero
        o.Sign = h.Sign;
    else
    {
        if (h.Exponent == 0) // Denormal (will convert to normalized)
        {
            // Adjust mantissa so it's normalized (and keep track of exp adjust)
            int e = -1;
            uint m = h.Mantissa;
            do
            {
                e++;
                m <<= 1;
            } while ((m & 0x400) == 0);

            o.Mantissa = (m & 0x3ff) << 13;
            o.Exponent = 127 - 15 - e;
            o.Sign = h.Sign;
        }
        else if (h.Exponent == 0x1f) // Inf/NaN
        {
            // NOTE: It's safe to treat both with the same code path by just truncating
            // lower Mantissa bits in NaNs (this is valid).
            o.Mantissa = h.Mantissa << 13;
            o.Exponent = 255;
            o.Sign = h.Sign;
        }
        else // Normalized number
        {
            o.Mantissa = h.Mantissa << 13;
            o.Exponent = 127 - 15 + h.Exponent;
            o.Sign = h.Sign;
        }
    }

    return o;
}
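// Quick sanity check on the exponent math above: FP16 uses bias 15, FP32 uses
// bias 127, so a normalized half with exponent field E encodes 2^(E-15) and
// needs float exponent field E - 15 + 127. E.g. 0x3c00 (1.0 as half) has E=15,
// M=0, which maps to exponent field 127, mantissa 0: 0x3f800000 = 1.0f.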
// Conversion tables
static uint tab1[256], tab2[256], tab3[256];

static void init_tables()
{
    FP16 f16;
    FP32 f32;

    for (int i=0; i < 256; i++)
    {
        f16.u = i << 8;
        f32 = half_to_float_full(f16);
        tab1[i] = f32.u;
        tab2[i] = 1 << 13;

        f16.u = i;
        f32 = half_to_float_full(f16);
        tab3[i] = f32.u;
    }

    // Lower exponent end has some denormals
    tab2[0x03] = 1 << 14;
    tab2[0x02] = 1 << 14;
    tab2[0x01] = 1 << 15;
    tab2[0x83] = 1 << 14;
    tab2[0x82] = 1 << 14;
    tab2[0x81] = 1 << 15;
}
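// How the table scheme works: tab1[t] holds the float bits for the half whose
// top byte (sign + exponent + top 2 mantissa bits) is t and whose low byte is
// zero; tab2[t] is the float-bit increment per low-byte step, which is 1 << 13
// for normalized halves but larger for the denormal buckets patched above
// (their leading mantissa bit sits lower, so each input step moves the
// normalized float mantissa further). tab3 covers the tiny denormals
// (|h| <= 0xff), where the step is not constant within a bucket.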
static FP32 half_to_float_fast(FP16 h)
{
    FP32 o;

    if (h.u & 0x7f00)
        o.u = tab1[h.u >> 8] + tab2[h.u >> 8] * (h.u & 0xff);
    else
        o.u = ((h.u & 0x8000) << 16) | tab3[h.u & 0xff];

    return o;
}
static FP32 half_to_float_fast2(FP16 h)
{
    static const FP32 magic = { 126 << 23 };
    FP32 o;

    if (h.Exponent == 0) // Zero / Denormal
    {
        o.u = magic.u + h.Mantissa;
        o.f -= magic.f;
    }
    else
    {
        o.Mantissa = h.Mantissa << 13;
        if (h.Exponent == 0x1f) // Inf/NaN
            o.Exponent = 255;
        else
            o.Exponent = 127 - 15 + h.Exponent;
    }

    o.Sign = h.Sign;
    return o;
}
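// Why the magic-number trick in the denormal path works: magic.u = 126 << 23
// is the bit pattern of 0.5f, so magic.u + h.Mantissa is the float
// 0.5 + Mantissa * 2^-24. Subtracting 0.5 leaves exactly Mantissa * 2^-24,
// which is the value of the FP16 denormal -- the FP adder does the
// renormalization (the count-leading-zeros-and-shift work) for us.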
static FP32 half_to_float_fast3(FP16 h)
{
    static const FP32 magic = { 113 << 23 };
    static const uint shifted_exp = 0x7c00 << 13; // exponent mask after shift
    FP32 o;

    // mantissa+exponent
    uint shifted = (h.u & 0x7fff) << 13;
    uint exponent = shifted & shifted_exp;

    // exponent cases
    o.u = shifted;
    if (exponent == 0) // Zero / Denormal
    {
        o.u += magic.u;
        o.f -= magic.f;
    }
    else if (exponent == shifted_exp) // Inf/NaN
        o.u += (255 - 31) << 23;
    else
        o.u += (127 - 15) << 23;

    o.u |= (h.u & 0x8000) << 16; // copy sign bit
    return o;
}
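// Same trick, no bitfields: for normal numbers, rebiasing the exponent is just
// an integer add of (127 - 15) << 23 to the shifted bits; Inf/NaN instead need
// the exponent field forced from 31 to 255, i.e. an add of (255 - 31) << 23.
// The denormal magic is now 113 << 23 (= 2^-14 as a float) because the mantissa
// bits are pre-shifted up by 13: 113 << 23 plus the shifted mantissa encodes
// 2^-14 * (1 + Mantissa/1024), and subtracting 2^-14 again leaves
// Mantissa * 2^-24, same as before.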
static FP32 half_to_float_fast4(FP16 h)
{
    static const FP32 magic = { 113 << 23 };
    static const uint shifted_exp = 0x7c00 << 13; // exponent mask after shift
    FP32 o;

    o.u = (h.u & 0x7fff) << 13;     // exponent/mantissa bits
    uint exp = shifted_exp & o.u;   // just the exponent
    o.u += (127 - 15) << 23;        // exponent adjust

    // handle exponent special cases
    if (exp == shifted_exp)         // Inf/NaN?
        o.u += (128 - 16) << 23;    // extra exp adjust
    else if (exp == 0)              // Zero/Denormal?
    {
        o.u += 1 << 23;             // extra exp adjust
        o.f -= magic.f;             // renormalize
    }

    o.u |= (h.u & 0x8000) << 16;    // sign bit
    return o;
}
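// The restructuring vs. fast3: the normal-case adjust is applied
// unconditionally up front, and the special cases only add a correction on top
// (31 + 112 + 112 = 255 for Inf/NaN; 0 + 112 + 1 = 113 lands exactly on the
// denormal magic exponent). Unconditional work plus masked corrections is the
// shape that maps directly onto SSE2 compare/and/add sequences.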
static FP32 half_to_float_fast5(FP16 h)
{
    static const FP32 magic = { (254 - 15) << 23 };
    static const FP32 was_infnan = { (127 + 16) << 23 };
    FP32 o;

    o.u = (h.u & 0x7fff) << 13;     // exponent/mantissa bits
    o.f *= magic.f;                 // exponent adjust
    if (o.f >= was_infnan.f)        // make sure Inf/NaN survive
        o.u |= 255 << 23;
    o.u |= (h.u & 0x8000) << 16;    // sign bit
    return o;
}
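// Here the rebias is a float multiply: magic.f = 2^((254-15) - 127) = 2^112,
// and multiplying by a power of two rescales the exponent exactly, leaving the
// mantissa untouched. FP16 denormals shift into FP32 denormal bit patterns
// before the multiply, which is why this variant slows down when denormals
// occur. Half Inf/NaN come out of the multiply with exponent field
// 31 + 112 = 143, i.e. >= 2^16 = was_infnan.f, so the compare catches them,
// and NaN payloads survive since only the exponent bits get ORed in.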
static __m128 half_to_float4_SSE2(__m128i h)
{
#define SSE_CONST4(name, val) static const __declspec(align(16)) uint name[4] = { (val), (val), (val), (val) }
#define CONST(name) *(const __m128i *)&name
    SSE_CONST4(mask_nosign,      0x7fff);
    SSE_CONST4(mask_justsign,    0x8000);
    SSE_CONST4(mask_shifted_exp, 0x7c00 << 13);
    SSE_CONST4(expadjust_normal, (127 - 15) << 23);
    SSE_CONST4(expadjust_infnan, (128 - 16) << 23);
    SSE_CONST4(expadjust_denorm, 1 << 23);
    SSE_CONST4(magic_denorm,     113 << 23);

    __m128i mnosign    = CONST(mask_nosign);
    __m128i expmant    = _mm_and_si128(mnosign, h);
    __m128i justsign   = _mm_and_si128(h, CONST(mask_justsign));
    __m128i mshiftexp  = CONST(mask_shifted_exp);
    __m128i eadjust    = CONST(expadjust_normal);
    __m128i shifted    = _mm_slli_epi32(expmant, 13);
    __m128i adjusted   = _mm_add_epi32(eadjust, shifted);
    __m128i justexp    = _mm_and_si128(shifted, mshiftexp);
    __m128i zero       = _mm_setzero_si128();
    __m128i b_isinfnan = _mm_cmpeq_epi32(mshiftexp, justexp);
    __m128i b_isdenorm = _mm_cmpeq_epi32(zero, justexp);
    __m128i adj_infnan = _mm_and_si128(b_isinfnan, CONST(expadjust_infnan));
    __m128i adjusted2  = _mm_add_epi32(adjusted, adj_infnan);
    __m128i adj_den    = CONST(expadjust_denorm);
    __m128i den1       = _mm_add_epi32(adj_den, adjusted2);
    __m128 den2        = _mm_sub_ps(_mm_castsi128_ps(den1), *(const __m128 *)&magic_denorm);
    __m128 adjusted3   = _mm_and_ps(den2, _mm_castsi128_ps(b_isdenorm));
    __m128 adjusted4   = _mm_andnot_ps(_mm_castsi128_ps(b_isdenorm), _mm_castsi128_ps(adjusted2));
    __m128 adjusted5   = _mm_or_ps(adjusted3, adjusted4);
    __m128i sign       = _mm_slli_epi32(justsign, 16);
    __m128 final       = _mm_or_ps(adjusted5, _mm_castsi128_ps(sign));

    // ~21 SSE2 ops.
    return final;

#undef SSE_CONST4
#undef CONST
}
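// Branches become masks here: the _mm_cmpeq_epi32 results are all-ones or
// all-zeros per lane, so the and/andnot/or triple implements a per-lane select
// between the denormal result (den2) and the integer-adjusted result
// (adjusted2), while the Inf/NaN correction is simply ANDed with its mask and
// added unconditionally.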
static __m128 half_to_float4b_SSE2(__m128i h)
{
#define SSE_CONST4(name, val) static const __declspec(align(16)) uint name[4] = { (val), (val), (val), (val) }
#define CONST(name) *(const __m128i *)&name
    SSE_CONST4(mask_nosign,      0x7fff);
    SSE_CONST4(smallest_normal,  0x0400);
    SSE_CONST4(infinity,         0x7c00);
    SSE_CONST4(expadjust_normal, (127 - 15) << 23);
    SSE_CONST4(magic_denorm,     113 << 23);

    __m128i mnosign     = CONST(mask_nosign);
    __m128i eadjust     = CONST(expadjust_normal);
    __m128i smallest    = CONST(smallest_normal);
    __m128i infty       = CONST(infinity);
    __m128i expmant     = _mm_and_si128(mnosign, h);
    __m128i justsign    = _mm_xor_si128(h, expmant);
    __m128i b_notinfnan = _mm_cmpgt_epi32(infty, expmant);
    __m128i b_isdenorm  = _mm_cmpgt_epi32(smallest, expmant);
    __m128i shifted     = _mm_slli_epi32(expmant, 13);
    __m128i adj_infnan  = _mm_andnot_si128(b_notinfnan, eadjust);
    __m128i adjusted    = _mm_add_epi32(eadjust, shifted);
    __m128i den1        = _mm_add_epi32(shifted, CONST(magic_denorm));
    __m128i adjusted2   = _mm_add_epi32(adjusted, adj_infnan);
    __m128 den2         = _mm_sub_ps(_mm_castsi128_ps(den1), *(const __m128 *)&magic_denorm);
    __m128 adjusted3    = _mm_and_ps(den2, _mm_castsi128_ps(b_isdenorm));
    __m128 adjusted4    = _mm_andnot_ps(_mm_castsi128_ps(b_isdenorm), _mm_castsi128_ps(adjusted2));
    __m128 adjusted5    = _mm_or_ps(adjusted3, adjusted4);
    __m128i sign        = _mm_slli_epi32(justsign, 16);
    __m128 final        = _mm_or_ps(adjusted5, _mm_castsi128_ps(sign));

    // ~19 SSE2 ops.
    return final;

#undef SSE_CONST4
#undef CONST
}
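// The tweaks vs. the first port: the sign is extracted with one XOR against
// expmant instead of a second mask constant; the denormal and Inf/NaN tests
// become signed compares against the raw (unshifted) half values, which is
// safe since everything fits in the low 15 bits of each 32-bit lane; and the
// Inf/NaN case reuses expadjust_normal as its correction, since
// 31 + 112 + 112 = 255 -- one constant fewer to load.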
static __m128 half_to_float5_SSE2(__m128i h)
{
#define SSE_CONST4(name, val) static const __declspec(align(16)) uint name[4] = { (val), (val), (val), (val) }
#define CONST(name) *(const __m128i *)&name
#define CONSTF(name) *(const __m128 *)&name
    SSE_CONST4(mask_nosign, 0x7fff);
    SSE_CONST4(magic,       (254 - 15) << 23);
    SSE_CONST4(was_infnan,  0x7bff);
    SSE_CONST4(exp_infnan,  255 << 23);

    __m128i mnosign     = CONST(mask_nosign);
    __m128i expmant     = _mm_and_si128(mnosign, h);
    __m128i justsign    = _mm_xor_si128(h, expmant);
    __m128i expmant2    = expmant; // copy (just here for counting purposes)
    __m128i shifted     = _mm_slli_epi32(expmant, 13);
    __m128 scaled       = _mm_mul_ps(_mm_castsi128_ps(shifted), CONSTF(magic));
    __m128i b_wasinfnan = _mm_cmpgt_epi32(expmant2, CONST(was_infnan));
    __m128i sign        = _mm_slli_epi32(justsign, 16);
    __m128 infnanexp    = _mm_and_ps(_mm_castsi128_ps(b_wasinfnan), CONSTF(exp_infnan));
    __m128 sign_inf     = _mm_or_ps(_mm_castsi128_ps(sign), infnanexp);
    __m128 final        = _mm_or_ps(scaled, sign_inf);

    // ~11 SSE2 ops.
    return final;

#undef SSE_CONST4
#undef CONST
#undef CONSTF
}
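// Instead of comparing the scaled float against 2^16 as the scalar version
// does, this compares the raw half bits against 0x7bff (the largest finite
// half); anything greater had exponent field 31 and gets the FP32 Inf/NaN
// exponent ORed in. Note the multiply sits on the critical path
// (shift -> multiply -> final ORs), which is the serial dependency mentioned
// in the header comment.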
// make sure we don't get DCE on SSE code
__declspec(align(16)) float output[1024*4];

int main(int argc, char **argv)
{
    FP16 h;
    FP32 full, fast, fast2, fast3, fast4, fast5;

    init_tables();

    for (int i=0; i < 0x10000; i++)
    {
        h.u = i;
        full = half_to_float_full(h);
        fast = half_to_float_fast(h);
        fast2 = half_to_float_fast2(h);
        fast3 = half_to_float_fast3(h);
        fast4 = half_to_float_fast4(h);
        fast5 = half_to_float_fast5(h);

        if (full.u != fast.u || full.u != fast2.u || full.u != fast3.u || full.u != fast4.u || full.u != fast5.u)
        {
            printf("mismatch! val=%04x full=%08x fast=%08x fast2=%08x fast3=%08x fast4=%08x fast5=%08x\n", i, full.u, fast.u, fast2.u, fast3.u, fast4.u, fast5.u);
            return 1;
        }
    }

    for (int i=0; i < 0x10000; i += 4)
    {
        uint ref[4];
        uint ssein[4], sseout[4];

        for (int j=0; j < 4; j++)
        {
            ssein[j] = i + j;
            h.u = i + j;
            full = half_to_float_full(h);
            ref[j] = full.u;
        }

        __m128i in = _mm_loadu_si128((const __m128i *)ssein);
        __m128 out = half_to_float4b_SSE2(in);
        _mm_storeu_ps((float *)sseout, out);

        for (int j=0; j < 4; j++)
        {
            if (sseout[j] != ref[j])
            {
                printf("mismatch! val=%04x full=%08x fast4SSE2=%08x\n", i+j, ref[j], sseout[j]);
                return 1;
            }
        }
    }

    uint best = ~0u;
    int start = 0, end = 0x10000;

    for (int runs=0; runs < 15000; runs++)
    {
        __m128i vals = _mm_set_epi32(start + 3, start + 2, start + 1, start + 0);
        __m128i incr = _mm_set1_epi32(4);
        uint tstart = (uint) __rdtsc();

        for (int i=start; i < end; i += 4)
        {
            __m128 out = half_to_float4b_SSE2(vals);
            _mm_store_ps(&output[i & 1023], out);
            vals = _mm_add_epi32(vals, incr);
        }

        uint time = (uint) __rdtsc() - tstart;
        if (time < best)
            best = time;
    }

    printf("best: %u cycles = %.2f / vec\n", best, 4.0f * best / (end - start));
    printf("all ok.\n");
    return 0;
}
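// Build note (an editor's sketch, not part of the original gist): the code is
// MSVC-flavored (__declspec(align), <intrin.h>, __rdtsc), so under VS2010 or
// later something like
//
//   cl /O2 /arch:SSE2 halffloat.cpp
//
// should do, where "halffloat.cpp" stands for whatever name you saved this
// file under. A GCC/Clang port would need __attribute__((aligned(16))) (or
// alignas) in place of __declspec(align(16)) and <x86intrin.h> for __rdtsc.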