Skip to content

Instantly share code, notes, and snippets.

@djinn
Last active August 28, 2025 05:55
Show Gist options
  • Save djinn/004a9e94209e4b5563f5f70c5acc7e22 to your computer and use it in GitHub Desktop.
Save djinn/004a9e94209e4b5563f5f70c5acc7e22 to your computer and use it in GitHub Desktop.
Comparing DeepSeek UE8M0 and NVIDIA E4M3 CUDA FP8
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <vector>
// E4M3 FP8 Format (NVIDIA standard)
// 1 bit sign, 4 bits exponent (bias 7), 3 bits mantissa.
// The format has no infinity: S.1111.111 is NaN and the largest finite
// magnitude is 448 (S.1111.110). Exponent field 0 holds zero and the
// subnormals, which are multiples of 2^-9.
// Arithmetic is emulated by widening both operands to float, operating in
// float, and re-encoding the result.
struct E4M3 {
    uint8_t bits;  // Raw encoding: [sign:1][exponent:4][mantissa:3].

    // Zero-initialised value (+0).
    E4M3() : bits(0) {}
    // Wraps an already-encoded bit pattern without conversion.
    // Note: an `int` argument is ambiguous between this and the float
    // overload; cast explicitly.
    E4M3(uint8_t b) : bits(b) {}
    // Encodes a float (see float_to_e4m3 for rounding/saturation rules).
    E4M3(float f) : bits(float_to_e4m3(f)) {}

    // Encodes `f` as E4M3.
    //  - +/-0 encode as 0x00.
    //  - NaN and +/-infinity encode as 0x7F (NaN; the format has no infinity).
    //  - Magnitudes are rounded half-up to the nearest representable value.
    //  - Finite magnitudes too large for the format saturate to +/-448
    //    instead of turning into NaN.
    //  - Magnitudes below 2^-6 are encoded as subnormals (step 2^-9); values
    //    that round to zero keep their sign bit.
    static uint8_t float_to_e4m3(float f) {
        if (f == 0.0f) return 0;
        if (std::isnan(f) || std::isinf(f)) return 0x7F;  // NaN/Inf

        // memcpy is the well-defined way to read a float's bit pattern.
        uint32_t raw = 0;
        std::memcpy(&raw, &f, sizeof(raw));
        const uint32_t sign = (raw >> 31) & 1u;
        const uint32_t mantissa = raw & 0x7FFFFFu;
        // Remove the IEEE bias (127) and apply the E4M3 bias (7).
        int32_t exp = static_cast<int32_t>((raw >> 23) & 0xFFu) - 127 + 7;

        if (exp <= 0) {
            // |f| < 2^-6: subnormal range. Count 2^-9 steps, rounding half-up.
            // The count is at most 8, and 8 is exactly the bit pattern of the
            // smallest normal (exponent 1, mantissa 0), so no special case is
            // needed when rounding carries out of the subnormal range.
            const uint32_t steps =
                static_cast<uint32_t>(std::fabs(f) * 512.0f + 0.5f);
            return static_cast<uint8_t>((sign << 7) | steps);
        }

        // Keep the top 3 mantissa bits, rounding half-up on the 4th.
        uint32_t rounded = (mantissa + (1u << 19)) >> 20;
        if (rounded >= 8) {  // Rounding carried into the exponent.
            ++exp;
            rounded = 0;
        }

        // Exponent 15 is a normal binade (256..448); only mantissa 7 there is
        // reserved for NaN. Anything at or beyond that saturates to 448.
        if (exp > 15 || (exp == 15 && rounded == 7)) {
            return static_cast<uint8_t>((sign << 7) | 0x7Eu);
        }
        return static_cast<uint8_t>((sign << 7) |
                                    (static_cast<uint32_t>(exp) << 3) | rounded);
    }

    // Decodes the stored bit pattern to float. Every E4M3 value is exactly
    // representable in float, so decoding is lossless.
    float to_float() const {
        const uint32_t sign = (bits >> 7) & 1u;
        const uint32_t exp = (bits >> 3) & 0xFu;
        const uint32_t mantissa = bits & 0x7u;

        if (exp == 0) {
            // Zero or subnormal: mantissa * 2^-9, with the sign preserved.
            const float magnitude = static_cast<float>(mantissa) / 512.0f;
            return sign ? -magnitude : magnitude;
        }
        if (exp == 15 && mantissa == 7) {
            return std::numeric_limits<float>::quiet_NaN();
        }

        // Rebias the exponent (7 -> 127) and left-align the 3 mantissa bits
        // in the 23-bit IEEE mantissa field.
        const uint32_t ieee_bits =
            (sign << 31) | ((exp - 7u + 127u) << 23) | (mantissa << 20);
        float value = 0.0f;
        std::memcpy(&value, &ieee_bits, sizeof(value));
        return value;
    }

    E4M3 operator+(const E4M3& other) const {
        return E4M3(this->to_float() + other.to_float());
    }
    E4M3 operator-(const E4M3& other) const {
        return E4M3(this->to_float() - other.to_float());
    }
    E4M3 operator*(const E4M3& other) const {
        return E4M3(this->to_float() * other.to_float());
    }
    E4M3 operator/(const E4M3& other) const {
        return E4M3(this->to_float() / other.to_float());
    }
};
// UE8M0 FP8 Format (DeepSeek's format)
// 0 bits sign (unsigned), 8 bits exponent (bias 127), 0 bits mantissa, so
// every finite nonzero value is an exact power of two.
// In this emulation 0x00 stands for zero and 0xFF for infinity. Arithmetic is
// emulated by widening to float and re-encoding the result.
struct UE8M0 {
    uint8_t bits;  // Raw encoding: the biased exponent.

    // Zero-initialised value.
    UE8M0() : bits(0) {}
    // Wraps an already-encoded bit pattern without conversion.
    UE8M0(uint8_t b) : bits(b) {}
    // Encodes a float (see float_to_ue8m0).
    UE8M0(float f) : bits(float_to_ue8m0(f)) {}

    // Encodes `f` as UE8M0.
    //  - Zero and negative inputs encode as 0x00 (the format is unsigned).
    //  - NaN and +infinity encode as 0xFF.
    //  - Otherwise the mantissa is discarded, i.e. the value is truncated
    //    down to the nearest power of two. Float subnormals (exponent field
    //    0) therefore encode as 0x00.
    static uint8_t float_to_ue8m0(float f) {
        if (f <= 0.0f) return 0;  // Only positive values
        if (std::isnan(f) || std::isinf(f)) return 0xFF;

        // memcpy is the well-defined way to read a float's bit pattern.
        uint32_t raw = 0;
        std::memcpy(&raw, &f, sizeof(raw));

        // Both formats use bias 127, so the IEEE exponent field is already
        // the UE8M0 encoding. Inf/NaN (field 255) were handled above.
        return static_cast<uint8_t>((raw >> 23) & 0xFFu);
    }

    // Decodes the stored exponent to float: 0x00 -> 0, 0xFF -> +infinity,
    // anything else -> 2^(bits - 127).
    float to_float() const {
        if (bits == 0) return 0.0f;
        if (bits == 0xFF) return std::numeric_limits<float>::infinity();

        // Same bias on both sides: place the exponent in the IEEE exponent
        // field with zero sign and zero mantissa.
        const uint32_t ieee_bits = static_cast<uint32_t>(bits) << 23;
        float value = 0.0f;
        std::memcpy(&value, &ieee_bits, sizeof(value));
        return value;
    }

    UE8M0 operator+(const UE8M0& other) const {
        return UE8M0(this->to_float() + other.to_float());
    }
    UE8M0 operator-(const UE8M0& other) const {
        const float result = this->to_float() - other.to_float();
        return UE8M0(std::max(0.0f, result));  // Clamp to non-negative
    }
    UE8M0 operator*(const UE8M0& other) const {
        return UE8M0(this->to_float() * other.to_float());
    }
    UE8M0 operator/(const UE8M0& other) const {
        return UE8M0(this->to_float() / other.to_float());
    }
};
// Prevent compiler optimization: every benchmark result is folded into this
// volatile object. Volatile writes are observable side effects, so the
// arithmetic that feeds them cannot be discarded.
volatile uint8_t optimization_barrier = 0;

// Applies `op` to the first `n` element pairs of `a` and `b`, storing each
// result and touching the optimization barrier so the work is not elided.
template <typename T, typename BinaryOp>
static void apply_elementwise(const std::vector<T>& a, const std::vector<T>& b,
                              std::vector<T>& result, std::size_t n,
                              BinaryOp op) {
    for (std::size_t i = 0; i < n; ++i) {
        result[i] = op(a[i], b[i]);
        // Prevent optimization
        optimization_barrier = result[i].bits & 1;
    }
}

// Benchmark function template.
// Times one element-wise pass of `op` ('+', '-', '*' or '/') over `a` and `b`,
// writing into `result`. T must provide the four binary operators and a
// `bits` member.
// Only the first min(a.size(), b.size(), result.size()) elements are
// processed, so mismatched vectors are never indexed out of range.
// An unrecognised `op` performs no work and leaves `result` untouched.
// Returns the elapsed wall time in microseconds (fractional), measured with
// the monotonic steady_clock.
template<typename T>
double benchmark_operation(const std::vector<T>& a, const std::vector<T>& b,
                           std::vector<T>& result, char op) {
    const std::size_t n = std::min({a.size(), b.size(), result.size()});

    const auto start = std::chrono::steady_clock::now();
    switch (op) {
        case '+':
            apply_elementwise(a, b, result, n,
                              [](const T& x, const T& y) { return x + y; });
            break;
        case '-':
            apply_elementwise(a, b, result, n,
                              [](const T& x, const T& y) { return x - y; });
            break;
        case '*':
            apply_elementwise(a, b, result, n,
                              [](const T& x, const T& y) { return x * y; });
            break;
        case '/':
            apply_elementwise(a, b, result, n,
                              [](const T& x, const T& y) { return x / y; });
            break;
        default:
            break;  // Unknown operator: nothing to time.
    }
    const auto end = std::chrono::steady_clock::now();

    // Keep sub-microsecond resolution instead of truncating to an integer.
    return std::chrono::duration<double, std::micro>(end - start).count();
}
// Entry point. Generates random operands, times the four arithmetic operators
// for the E4M3 and UE8M0 emulations, then prints a format summary, sample bit
// patterns, and the error of the first ten sums against float arithmetic.
int main() {
    const size_t N = 1000000; // 1M operations
    const int ITERATIONS = 10;
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dis(0.1f, 100.0f);

    // Generate test data: the same float operands are encoded in both formats.
    std::vector<float> float_a(N), float_b(N);
    std::vector<E4M3> e4m3_a(N), e4m3_b(N), e4m3_result(N);
    std::vector<UE8M0> ue8m0_a(N), ue8m0_b(N), ue8m0_result(N);
    for (size_t i = 0; i < N; ++i) {
        float_a[i] = dis(gen);
        float_b[i] = dis(gen);
        e4m3_a[i] = E4M3(float_a[i]);
        e4m3_b[i] = E4M3(float_b[i]);
        ue8m0_a[i] = UE8M0(float_a[i]);
        ue8m0_b[i] = UE8M0(float_b[i]);
    }

    std::cout << "FP8 Format Arithmetic Benchmark\n";
    std::cout << "================================\n";
    std::cout << "Operations: " << N << " per test\n";
    std::cout << "Iterations: " << ITERATIONS << "\n\n";

    // Converts an average pass time in microseconds to operations per second;
    // a zero average (timer too coarse) reports 0 instead of dividing by zero.
    const auto ops_per_sec = [N](double avg_us) {
        return (avg_us > 0.0) ? (N * 1000000.0) / avg_us : 0.0;
    };

    char operations[] = {'+', '-', '*', '/'};
    const char* op_names[] = {"Addition", "Subtraction", "Multiplication", "Division"};
    for (int op_idx = 0; op_idx < 4; ++op_idx) {
        char op = operations[op_idx];
        std::cout << op_names[op_idx] << ":\n";
        std::cout << std::setw(12) << "Format" << std::setw(15) << "Avg Time (μs)"
                  << std::setw(15) << "Ops/sec" << std::setw(20) << "Relative Speed\n";
        std::cout << std::string(60, '-') << "\n";

        // Benchmark E4M3
        double e4m3_total = 0;
        for (int i = 0; i < ITERATIONS; ++i) {
            e4m3_total += benchmark_operation(e4m3_a, e4m3_b, e4m3_result, op);
        }
        const double e4m3_avg = e4m3_total / ITERATIONS;
        const double e4m3_ops_per_sec = ops_per_sec(e4m3_avg);

        // Benchmark UE8M0
        double ue8m0_total = 0;
        for (int i = 0; i < ITERATIONS; ++i) {
            ue8m0_total += benchmark_operation(ue8m0_a, ue8m0_b, ue8m0_result, op);
        }
        const double ue8m0_avg = ue8m0_total / ITERATIONS;
        const double ue8m0_ops_per_sec = ops_per_sec(ue8m0_avg);

        // Relative speed of UE8M0 with E4M3 as the 1.00x baseline. Speed is
        // the inverse of time, so the ratio is baseline time over UE8M0 time:
        // values above 1 mean UE8M0 was faster. Falls back to 1.0 when the
        // UE8M0 time is zero.
        const double speed_ratio = (ue8m0_avg > 0.0) ? e4m3_avg / ue8m0_avg : 1.0;

        std::cout << std::setw(12) << "E4M3" << std::setw(15) << std::fixed << std::setprecision(1) << e4m3_avg
                  << std::setw(15) << std::scientific << std::setprecision(2) << e4m3_ops_per_sec
                  << std::setw(20) << std::fixed << std::setprecision(2) << "1.00x\n";
        std::cout << std::setw(12) << "UE8M0" << std::setw(15) << std::fixed << std::setprecision(1) << ue8m0_avg
                  << std::setw(15) << std::scientific << std::setprecision(2) << ue8m0_ops_per_sec
                  << std::setw(20) << std::fixed << std::setprecision(2) << speed_ratio << "x\n\n";
    }

    // Read the optimization barrier so the compiler cannot remove it
    if (optimization_barrier == 255) std::cout << ""; // Never executes but prevents optimization

    // Format analysis
    std::cout << "\nFormat Analysis:\n";
    std::cout << "================\n";
    std::cout << "E4M3 (NVIDIA Standard):\n";
    std::cout << " - 1 bit sign, 4 bits exponent, 3 bits mantissa\n";
    std::cout << " - Range: ~±448 (largest finite value)\n";
    std::cout << " - Precision: 8 discrete levels between powers of 2\n\n";
    std::cout << "UE8M0 (DeepSeek Custom):\n";
    std::cout << " - 0 bits sign (unsigned), 8 bits exponent, 0 bits mantissa\n";
    std::cout << " - Range: 0 to ~10^38 (huge dynamic range)\n";
    std::cout << " - Precision: Only powers of 2 (very coarse)\n\n";

    // Show bit patterns for some example values
    std::cout << "Bit Pattern Examples:\n";
    std::cout << "====================\n";
    std::cout << std::setw(10) << "Value" << std::setw(15) << "E4M3 (binary)"
              << std::setw(15) << "UE8M0 (binary)" << std::setw(12) << "E4M3 Dec" << std::setw(12) << "UE8M0 Dec\n";
    std::cout << std::string(65, '-') << "\n";
    float test_vals[] = {1.0f, 2.0f, 4.0f, 8.0f, 16.0f, 0.5f, 0.25f};
    for (float val : test_vals) {
        E4M3 e4m3_val(val);
        UE8M0 ue8m0_val(val);
        // Convert to binary string, most significant bit first
        std::string e4m3_binary = "";
        std::string ue8m0_binary = "";
        for (int i = 7; i >= 0; i--) {
            e4m3_binary += ((e4m3_val.bits >> i) & 1) ? '1' : '0';
            ue8m0_binary += ((ue8m0_val.bits >> i) & 1) ? '1' : '0';
        }
        // Cast to int so the encodings print as numbers, not characters.
        std::cout << std::setw(10) << std::fixed << std::setprecision(2) << val
                  << std::setw(15) << e4m3_binary
                  << std::setw(15) << ue8m0_binary
                  << std::setw(12) << static_cast<int>(e4m3_val.bits)
                  << std::setw(12) << static_cast<int>(ue8m0_val.bits) << "\n";
    }

    // Accuracy comparison: float sum vs. the sum computed in each FP8 format
    std::cout << "Accuracy Comparison (first 10 values):\n";
    std::cout << "=====================================\n";
    std::cout << std::setw(8) << "Original" << std::setw(12) << "E4M3" << std::setw(12) << "UE8M0"
              << std::setw(12) << "E4M3 Error" << std::setw(12) << "UE8M0 Error\n";
    std::cout << std::string(55, '-') << "\n";
    for (int i = 0; i < 10; ++i) {
        float orig = float_a[i] + float_b[i];
        float e4m3_val = (e4m3_a[i] + e4m3_b[i]).to_float();
        float ue8m0_val = (ue8m0_a[i] + ue8m0_b[i]).to_float();
        std::cout << std::setw(8) << std::fixed << std::setprecision(3) << orig
                  << std::setw(12) << std::setprecision(3) << e4m3_val
                  << std::setw(12) << std::setprecision(3) << ue8m0_val
                  << std::setw(12) << std::setprecision(3) << std::abs(orig - e4m3_val)
                  << std::setw(12) << std::setprecision(3) << std::abs(orig - ue8m0_val) << "\n";
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment