Skip to content

Instantly share code, notes, and snippets.

@mattip
Last active November 20, 2024 06:10
Show Gist options
  • Save mattip/9cd33dd43155bd71e981f6e985778c06 to your computer and use it in GitHub Desktop.
Save mattip/9cd33dd43155bd71e981f6e985778c06 to your computer and use it in GitHub Desktop.
Exploring fast and accurate floating point summation

From numpy/numpy#22956

Here is some code implementing Kahan (compensated) summation. It is accurate, but it is not fast.

import cffi
import numpy as np
import time

# Build a small C extension with cffi exposing two float reductions, so that
# Kahan-compensated summation can be compared against a naive accumulation.
ffi = cffi.FFI()

# Declarations visible from Python; they must match the C definitions that are
# handed to set_source() below.
ffi.cdef("""
    float accurate_sum(float *values, int n);
    float inaccurate_sum(float *values, int n);
""")

ffi.set_source("test_sums", """

/*
 * One Kahan (compensated) summation step: add y to *sum and store in *c the
 * rounding error of that addition, so it is subtracted from the next input.
 *
 * NOTE: this relies on the compiler evaluating the expressions as written.
 * Value-unsafe optimisation flags (e.g. GCC/Clang -ffast-math) are allowed to
 * simplify (t - sum) - y to zero, which silently removes the compensation.
 */
static inline
void kahan_sum(float *sum, float *c, float y) {
    float t;
    y -= c[0];
    t = sum[0] + y;
    c[0] = (t - sum[0]) - y;
    sum[0] = t;
}

/* Sum values[0..n-1] using Kahan compensation; returns 0 when n <= 0. */
float accurate_sum(float *values, int n) {
    float sum = 0.0f;
    float c = 0.0f;
    for (int i = 0; i < n; i++) {
        kahan_sum(&sum, &c, values[i]);
    }
    return sum;
}

/* Plain left-to-right float accumulation, kept as the baseline. */
float inaccurate_sum(float *values, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += values[i];
    }
    return sum;
}
""")

# Generate the C wrapper and compile the "test_sums" extension module (with
# cffi's default arguments the build output lands in the current directory),
# so that it can be imported right after this call.
ffi.compile()

# Import the extension that was just built. Note that this rebinds ``ffi`` to
# the compiled module's FFI object, which is the one used below to create the
# C view of the array.
from test_sums import ffi, lib

n = 20_000_000
a = np.ones(n, dtype=np.float32)
# from_buffer() exposes the array's memory to C without copying and keeps the
# array alive for as long as the returned cdata exists, unlike casting the raw
# integer address taken from ``a.ctypes.data``.
m = ffi.from_buffer("float[]", a)


def _time_sum(label, func, values, count, repeats=10):
    """Call ``func(values, count)`` ``repeats`` times and print result and time.

    Parameters
    ----------
    label : str
        Prefix printed in front of the result.
    func : callable
        One of the C summation functions; called as ``func(values, count)``.
    values : cdata
        ``float`` buffer passed through to ``func``.
    count : int
        Number of elements of ``values`` to sum.
    repeats : int, optional
        How many times the call is repeated inside the timed region.

    The printed time is the total for all repeats, in seconds.
    """
    total = None
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(repeats):
        total = func(values, count)
    elapsed = time.perf_counter() - start
    print(f"{label} {total}, took {elapsed:.2f}")


# Summing 20 million float32 ones: a plain float32 accumulator stops growing
# at 2**24 = 16777216 (adding 1.0 no longer changes it), which is the error the
# compensated version is meant to avoid.
_time_sum("accurate", lib.accurate_sum, m, n)
_time_sum("inaccurate", lib.inaccurate_sum, m, n)

On my machine the Kahan (accurate) implementation is about 5x slower than the naive one:

$ python kahan.py 
accurate 20000000.0, took 0.73
inaccurate 16777216.0, took 0.14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment