Last active
August 28, 2025 05:55
-
-
Save djinn/004a9e94209e4b5563f5f70c5acc7e22 to your computer and use it in GitHub Desktop.
Comparing DeepSeek UE8M0 and NVIDIA E4M3 CUDA FP8 formats
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <vector>
// E4M3 FP8 Format (NVIDIA / OCP FP8, the "FN" variant without infinities)
// 1 bit sign, 4 bits exponent (bias 7), 3 bits mantissa.
// S.1111.111 is NaN; the largest finite magnitude is 448 (S.1111.110) and the
// smallest subnormal is 2^-9 (S.0000.001).
struct E4M3 {
    uint8_t bits;

    E4M3() : bits(0) {}
    E4M3(uint8_t b) : bits(b) {}
    E4M3(float f) : bits(float_to_e4m3(f)) {}

    // Encodes a float as E4M3 using round-to-nearest-even.
    //  - +0 and -0 encode as 0x00.
    //  - NaN and +/-Inf encode as 0x7F (NaN); E4M3 has no infinity.
    //  - Finite magnitudes above 448 saturate to +/-448 (0x7E / 0xFE).
    //  - Magnitudes that round below the smallest subnormal become signed zero.
    static uint8_t float_to_e4m3(float f) {
        if (f == 0.0f) return 0;
        if (std::isnan(f) || std::isinf(f)) return 0x7F;  // NaN/Inf
        const uint8_t sign = std::signbit(f) ? 0x80 : 0x00;
        const float mag = std::fabs(f);
        if (mag > 448.0f) return static_cast<uint8_t>(sign | 0x7E);  // saturate to max finite

        // mag = frac * 2^e2 with frac in [0.5, 1), i.e. mag lies in [2^(e2-1), 2^e2).
        int e2 = 0;
        std::frexp(mag, &e2);
        int unbiased = e2 - 1;
        // Below the normal range (2^-6) all values share the subnormal spacing 2^-9.
        if (unbiased < -6) unbiased = -6;
        // Representable values in this binade are m * 2^(unbiased-3). Scaling by a
        // power of two is exact; nearbyint rounds half-to-even in the default mode.
        int m = static_cast<int>(std::nearbyint(std::ldexp(mag, 3 - unbiased)));
        if (m >= 16) {  // rounding carried into the next binade
            ++unbiased;
            m = 8;
        }
        if (m < 8) {  // exponent field 0: subnormal, or underflow to signed zero
            return static_cast<uint8_t>(sign | m);
        }
        return static_cast<uint8_t>(sign | ((unbiased + 7) << 3) | (m - 8));
    }

    // Decodes to float. 0x7F / 0xFF decode to quiet NaN. Exponent field 0 holds
    // subnormals m * 2^-9 (0x80 decodes to -0.0f).
    float to_float() const {
        const bool negative = (bits & 0x80) != 0;
        const int exp_field = (bits >> 3) & 0xF;
        const int mantissa = bits & 0x7;
        if (exp_field == 15 && mantissa == 7) return std::numeric_limits<float>::quiet_NaN();
        const float magnitude = (exp_field == 0)
            ? std::ldexp(static_cast<float>(mantissa), -9)                   // 2^-6 * m/8
            : std::ldexp(static_cast<float>(8 + mantissa), exp_field - 10);  // 2^(E-7) * (1 + m/8)
        return negative ? -magnitude : magnitude;
    }

    // Arithmetic is emulated: decode both operands to float, compute in float,
    // and re-encode the result (one rounding step per operation).
    E4M3 operator+(const E4M3& other) const {
        return E4M3(this->to_float() + other.to_float());
    }
    E4M3 operator-(const E4M3& other) const {
        return E4M3(this->to_float() - other.to_float());
    }
    E4M3 operator*(const E4M3& other) const {
        return E4M3(this->to_float() * other.to_float());
    }
    E4M3 operator/(const E4M3& other) const {
        return E4M3(this->to_float() / other.to_float());
    }
};
// UE8M0 FP8 Format (DeepSeek's format)
// 0 bits sign (unsigned), 8 bits exponent (bias 127), 0 bits mantissa:
// every nonzero finite code is an exact power of two, 2^(bits - 127).
// Conventions used in this file: 0x00 encodes zero and 0xFF encodes +infinity.
// (The OCP MX E8M0 scale format instead has no zero and reserves 0xFF for NaN.)
struct UE8M0 {
    uint8_t bits;

    UE8M0() : bits(0) {}
    UE8M0(uint8_t b) : bits(b) {}
    UE8M0(float f) : bits(float_to_ue8m0(f)) {}

    // Encodes a float as the nearest power of two (linear distance, ties round up),
    // matching the round-to-nearest behaviour of the E4M3 encoder.
    // Non-positive inputs map to 0 (zero); NaN, Inf and overflow map to 0xFF.
    static uint8_t float_to_ue8m0(float f) {
        if (f <= 0.0f) return 0;  // Only positive values
        if (std::isnan(f) || std::isinf(f)) return 0xFF;
        int e2 = 0;
        const float frac = std::frexp(f, &e2);  // f = frac * 2^e2, frac in [0.5, 1)
        // f lies in [2^(e2-1), 2^e2); the linear midpoint 1.5 * 2^(e2-1) is frac == 0.75.
        const int biased = (e2 - 1) + 127 + (frac >= 0.75f ? 1 : 0);
        if (biased <= 0) return 0;       // underflow to zero
        if (biased >= 255) return 0xFF;  // overflow
        return static_cast<uint8_t>(biased);
    }

    // Decodes to float: 0x00 -> 0, 0xFF -> +inf, otherwise exactly 2^(bits - 127).
    float to_float() const {
        if (bits == 0) return 0.0f;
        if (bits == 0xFF) return std::numeric_limits<float>::infinity();
        return std::ldexp(1.0f, static_cast<int>(bits) - 127);
    }

    // Arithmetic is emulated in float and re-encoded (one rounding per operation).
    UE8M0 operator+(const UE8M0& other) const {
        return UE8M0(this->to_float() + other.to_float());
    }
    UE8M0 operator-(const UE8M0& other) const {
        float result = this->to_float() - other.to_float();
        return UE8M0(std::max(0.0f, result)); // Clamp to non-negative
    }
    UE8M0 operator*(const UE8M0& other) const {
        return UE8M0(this->to_float() * other.to_float());
    }
    UE8M0 operator/(const UE8M0& other) const {
        return UE8M0(this->to_float() / other.to_float());
    }
};
// Benchmark sink: every store to this volatile object is an observable side
// effect, so the optimizer must keep the stores that feed it.
volatile uint8_t optimization_barrier{0};
// Benchmark function template | |
template<typename T> | |
double benchmark_operation(const std::vector<T>& a, const std::vector<T>& b, | |
std::vector<T>& result, char op) { | |
auto start = std::chrono::high_resolution_clock::now(); | |
size_t n = a.size(); | |
switch(op) { | |
case '+': | |
for (size_t i = 0; i < n; ++i) { | |
result[i] = a[i] + b[i]; | |
// Prevent optimization | |
optimization_barrier = result[i].bits & 1; | |
} | |
break; | |
case '-': | |
for (size_t i = 0; i < n; ++i) { | |
result[i] = a[i] - b[i]; | |
optimization_barrier = result[i].bits & 1; | |
} | |
break; | |
case '*': | |
for (size_t i = 0; i < n; ++i) { | |
result[i] = a[i] * b[i]; | |
optimization_barrier = result[i].bits & 1; | |
} | |
break; | |
case '/': | |
for (size_t i = 0; i < n; ++i) { | |
result[i] = a[i] / b[i]; | |
optimization_barrier = result[i].bits & 1; | |
} | |
break; | |
} | |
auto end = std::chrono::high_resolution_clock::now(); | |
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start); | |
return duration.count(); | |
} | |
int main() { | |
const size_t N = 1000000; // 1M operations | |
const int ITERATIONS = 10; | |
std::random_device rd; | |
std::mt19937 gen(rd()); | |
std::uniform_real_distribution<float> dis(0.1f, 100.0f); | |
// Generate test data | |
std::vector<float> float_a(N), float_b(N); | |
std::vector<E4M3> e4m3_a(N), e4m3_b(N), e4m3_result(N); | |
std::vector<UE8M0> ue8m0_a(N), ue8m0_b(N), ue8m0_result(N); | |
for (size_t i = 0; i < N; ++i) { | |
float_a[i] = dis(gen); | |
float_b[i] = dis(gen); | |
e4m3_a[i] = E4M3(float_a[i]); | |
e4m3_b[i] = E4M3(float_b[i]); | |
ue8m0_a[i] = UE8M0(float_a[i]); | |
ue8m0_b[i] = UE8M0(float_b[i]); | |
} | |
std::cout << "FP8 Format Arithmetic Benchmark\n"; | |
std::cout << "================================\n"; | |
std::cout << "Operations: " << N << " per test\n"; | |
std::cout << "Iterations: " << ITERATIONS << "\n\n"; | |
char operations[] = {'+', '-', '*', '/'}; | |
const char* op_names[] = {"Addition", "Subtraction", "Multiplication", "Division"}; | |
for (int op_idx = 0; op_idx < 4; ++op_idx) { | |
char op = operations[op_idx]; | |
std::cout << op_names[op_idx] << ":\n"; | |
std::cout << std::setw(12) << "Format" << std::setw(15) << "Avg Time (μs)" | |
<< std::setw(15) << "Ops/sec" << std::setw(20) << "Relative Speed\n"; | |
std::cout << std::string(60, '-') << "\n"; | |
// Benchmark E4M3 | |
double e4m3_total = 0; | |
for (int i = 0; i < ITERATIONS; ++i) { | |
e4m3_total += benchmark_operation(e4m3_a, e4m3_b, e4m3_result, op); | |
} | |
double e4m3_avg = e4m3_total / ITERATIONS; | |
double e4m3_ops_per_sec = (N * 1000000.0) / e4m3_avg; | |
// Benchmark UE8M0 | |
double ue8m0_total = 0; | |
for (int i = 0; i < ITERATIONS; ++i) { | |
ue8m0_total += benchmark_operation(ue8m0_a, ue8m0_b, ue8m0_result, op); | |
} | |
double ue8m0_avg = ue8m0_total / ITERATIONS; | |
double ue8m0_ops_per_sec = (N * 1000000.0) / ue8m0_avg; | |
// Calculate relative speed (handle division by zero) | |
double speed_ratio = (e4m3_avg > 0.001) ? ue8m0_avg / e4m3_avg : 1.0; | |
std::cout << std::setw(12) << "E4M3" << std::setw(15) << std::fixed << std::setprecision(1) << e4m3_avg | |
<< std::setw(15) << std::scientific << std::setprecision(2) << e4m3_ops_per_sec | |
<< std::setw(20) << std::fixed << std::setprecision(2) << "1.00x\n"; | |
std::cout << std::setw(12) << "UE8M0" << std::setw(15) << std::fixed << std::setprecision(1) << ue8m0_avg | |
<< std::setw(15) << std::scientific << std::setprecision(2) << ue8m0_ops_per_sec | |
<< std::setw(20) << std::fixed << std::setprecision(2) << speed_ratio << "x\n\n"; | |
} | |
// Print the optimization barrier to prevent compiler from removing it | |
if (optimization_barrier == 255) std::cout << ""; // Never executes but prevents optimization | |
// Format analysis | |
std::cout << "\nFormat Analysis:\n"; | |
std::cout << "================\n"; | |
std::cout << "E4M3 (NVIDIA Standard):\n"; | |
std::cout << " - 1 bit sign, 4 bits exponent, 3 bits mantissa\n"; | |
std::cout << " - Range: ~±448 (largest finite value)\n"; | |
std::cout << " - Precision: 8 discrete levels between powers of 2\n\n"; | |
std::cout << "UE8M0 (DeepSeek Custom):\n"; | |
std::cout << " - 0 bits sign (unsigned), 8 bits exponent, 0 bits mantissa\n"; | |
std::cout << " - Range: 0 to ~10^38 (huge dynamic range)\n"; | |
std::cout << " - Precision: Only powers of 2 (very coarse)\n\n"; | |
// Show bit patterns for some example values | |
std::cout << "Bit Pattern Examples:\n"; | |
std::cout << "====================\n"; | |
std::cout << std::setw(10) << "Value" << std::setw(15) << "E4M3 (binary)" | |
<< std::setw(15) << "UE8M0 (binary)" << std::setw(12) << "E4M3 Dec" << std::setw(12) << "UE8M0 Dec\n"; | |
std::cout << std::string(65, '-') << "\n"; | |
float test_vals[] = {1.0f, 2.0f, 4.0f, 8.0f, 16.0f, 0.5f, 0.25f}; | |
for (float val : test_vals) { | |
E4M3 e4m3_val(val); | |
UE8M0 ue8m0_val(val); | |
// Convert to binary string | |
std::string e4m3_binary = ""; | |
std::string ue8m0_binary = ""; | |
for (int i = 7; i >= 0; i--) { | |
e4m3_binary += ((e4m3_val.bits >> i) & 1) ? '1' : '0'; | |
ue8m0_binary += ((ue8m0_val.bits >> i) & 1) ? '1' : '0'; | |
} | |
std::cout << std::setw(10) << std::fixed << std::setprecision(2) << val | |
<< std::setw(15) << e4m3_binary | |
<< std::setw(15) << ue8m0_binary | |
<< std::setw(12) << (int)e4m3_val.bits | |
<< std::setw(12) << (int)ue8m0_val.bits << "\n"; | |
} | |
// Accuracy comparison | |
std::cout << "Accuracy Comparison (first 10 values):\n"; | |
std::cout << "=====================================\n"; | |
std::cout << std::setw(8) << "Original" << std::setw(12) << "E4M3" << std::setw(12) << "UE8M0" | |
<< std::setw(12) << "E4M3 Error" << std::setw(12) << "UE8M0 Error\n"; | |
std::cout << std::string(55, '-') << "\n"; | |
for (int i = 0; i < 10; ++i) { | |
float orig = float_a[i] + float_b[i]; | |
float e4m3_val = (e4m3_a[i] + e4m3_b[i]).to_float(); | |
float ue8m0_val = (ue8m0_a[i] + ue8m0_b[i]).to_float(); | |
std::cout << std::setw(8) << std::fixed << std::setprecision(3) << orig | |
<< std::setw(12) << std::setprecision(3) << e4m3_val | |
<< std::setw(12) << std::setprecision(3) << ue8m0_val | |
<< std::setw(12) << std::setprecision(3) << std::abs(orig - e4m3_val) | |
<< std::setw(12) << std::setprecision(3) << std::abs(orig - ue8m0_val) << "\n"; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment