Skip to content

Instantly share code, notes, and snippets.

@going-digital
Last active May 9, 2024 21:07
Show Gist options
Save going-digital/02e46c44d89237c07bc99cd440ebfa43 to your computer and use it in GitHub Desktop.
WebAssembly native sin, log and exp functions optimised for code size.
;; Native implementations of sin, log and exp functions.
;; sintau: 41 bytes code, 34 bytes shared code, 24 bytes data
;; exp2: 25 bytes code, 34 bytes shared code, 20 bytes data
;; log2: 37 bytes code, 34 bytes shared code, 24 bytes data
;; Total 137 bytes code, 68 bytes data
;; wasm-opt -Oz tries to optimise away $half by replacing each use with an f32.const, but that makes the code larger, not smaller.
;; Polynomial coefficients calculated by accompanying python script.
;; The start/end address arguments of each call to $evalpoly must be updated by hand if a polynomial's length changes.
;; Note that these implementations are undefined for NaNs, Infs, denormals or other invalid inputs.
(module
  ;; All three exported functions share $evalpoly and the coefficient table
  ;; in the data segment at the end of this module. NaN, Inf, denormal and
  ;; other out-of-domain inputs give undefined results.

  (export "sintau" (func $sintau))
  ;;
  ;; sintau(x) returns sin(2*pi*x)
  ;;
  ;; x is range-reduced to 0..0.25, a polynomial (coefficients at bytes
  ;; 0..24) is evaluated there, and the sign is restored with copysign.
  (func $sintau
    (param $x f32)
    (result f32)
    (local $x1 f32)
    ;; 0.5 is kept in a local: each local.get is smaller than an f32.const
    (local $half f32)
    (f32.copysign
      (call $evalpoly
        ;; Reduce to 0..0.25 by folding about 0.25
        (f32.min
          (f32.sub
            (local.tee $half (f32.const 0.5))
            ;; Reduce to 0..0.5 by folding about 0.5
            (local.tee $x1 (f32.abs (f32.sub
              ;; Reduce $x to range 0..1 (fractional part of x)
              (local.tee $x
                (f32.sub (local.get $x) (f32.floor (local.get $x)))
              )
              (local.get $half)
            )))
          )
          (local.get $x1)
        )
        (i32.const 0) (i32.const 24)
      )
      ;; Sign of result: positive for reduced x in 0..0.5, negative above 0.5
      (f32.sub
        (local.get $half)
        (local.get $x)
      )
    )
  )

  (export "exp2" (func $exp2))
  ;;
  ;; exp2(x) returns pow(2,x)
  ;;
  ;; A polynomial (coefficients at bytes 24..44) approximates 2^frac(x),
  ;; which is roughly 1..2. floor(x) is then added directly to the IEEE754
  ;; exponent field. i32.trunc_f32_s traps if floor(x) does not fit in an
  ;; i32; results are undefined anyway once 2^x leaves the normal f32 range.
  (func $exp2
    (param $x f32)
    (result f32)
    (f32.reinterpret_i32
      (i32.add
        (i32.reinterpret_f32
          (call $evalpoly
            ;; Fractional part of x; $x is overwritten with floor(x) here
            (f32.sub
              (local.get $x)
              (local.tee $x (f32.floor (local.get $x)))
            )
            (i32.const 24) (i32.const 44)
          )
        )
        ;; floor(x) shifted into the exponent field (bits 23 and up)
        (i32.shl
          (i32.trunc_f32_s (local.get $x))
          (i32.const 23)
        )
      )
    )
  )

  (export "log2" (func $log2))
  ;;
  ;; log2(x) returns log base 2 of x = ln(x) / ln(2)
  ;;
  ;; Only meaningful for positive normal x: the sign bit is not masked off
  ;; before the exponent is extracted.
  (func $log2
    (param $x f32)
    (result f32)
    (local $xi i32)
    (f32.add
      ;; Extract unbiased exponent of $x
      (f32.convert_i32_s
        (i32.sub
          (i32.shr_u
            (local.tee $xi (i32.reinterpret_f32 (local.get $x)))
            (i32.const 23))
          (i32.const 127)
        )
      )
      ;; Logarithm of the mantissa via a polynomial (bytes 44..68)
      (call $evalpoly
        ;; First parameter is (mantissa-1): the 23 fraction bits are moved to
        ;; the top of the i32, then scaled by 2^-32 into 0..1
        (f32.div
          (f32.convert_i32_u
            (i32.shl
              (local.get $xi)
              (i32.const 9)
            )
          )
          (f32.const 4294967296)
        )
        (i32.const 44) (i32.const 68)
      )
    )
  )

  ;; Internal function to evaluate polynomials by Horner's rule.
  ;; Uses coefficients calculated by calc_coef.py, stored as f32 values with
  ;; the highest-degree coefficient first, in the byte range [$start, $end).
  ;; The body runs before the end test, so $end - $start must be a non-zero
  ;; multiple of 4.
  (func $evalpoly (param $x f32) (param $start i32) (param $end i32) (result f32)
    ;; Locals are zero-initialised, so the first iteration yields coef[0]
    (local $result f32)
    (loop $loop
      (local.set $result
        (f32.add
          (f32.mul (local.get $result) (local.get $x))
          (f32.load (local.get $start))
        )
      )
      ;; Advance to the next coefficient; loop until $start reaches $end
      (br_if $loop
        (i32.sub
          (local.tee $start
            (i32.add (local.get $start) (i32.const 4))
          )
          (local.get $end))
      )
    )
    (local.get $result)
  )

  ;; Memory size is counted in 64 KiB pages, not bytes: one page holds the
  ;; 68-byte coefficient table. No maximum is declared; the module never
  ;; grows memory, and omitting it saves a byte in the binary.
  (memory 1)
  (data (i32.const 0)
    ;; sintau polynomial coefficients (bytes 0..24)
    "\3f\c7\61\42" "\d9\e0\13\41" "\4b\aa\2a\c2" "\73\b2\a6\3d" "\40\01\c9\40" "\7e\95\d0\36"
    ;; exp2 polynomial coefficients (bytes 24..44)
    "\6f\f9\5f\3c" "\90\f2\53\3d" "\22\67\77\3e" "\ac\66\31\3f" "\1d\00\80\3f"
    ;; log2 polynomial coefficients (bytes 44..68)
    "\f7\25\30\3d" "\03\fd\3f\be" "\17\a6\d1\3e" "\4c\dc\34\bf" "\d3\82\b8\3f" "\fc\88\8a\37"
  )
)
import struct
# Calculate best fit polynomials for function mappings.
# Chebyshev fitting is used to minimise overall error
# Note that log2 input is the bit representation of the IEEE754 mantissa, which has 1 pre-subtracted which is why x+1 is used.
from mpmath import *
mp.dps = 20
mp.pretty = True

# Target mappings to approximate. 'xrange' is the already range-reduced
# interval that the WebAssembly code passes to $evalpoly, and 'max_error' is
# the largest acceptable fit error as reported by chebyfit.
functions = [
    {'name': 'sintau', 'function': lambda x: sin(x*2*pi), 'xrange': [0, 0.25], 'max_error': 1e-4},
    {'name': 'exp2', 'function': lambda x: power(2, x), 'xrange': [0, 1], 'max_error': 1e-4},
    {'name': 'log2', 'function': lambda x: log(x+1) / log(2), 'xrange': [0, 1], 'max_error': 1e-4},
]

# chebyfit's N argument is the number of coefficients (polynomial degree
# N-1). The search tries 1..MAX_TERMS coefficients and keeps the first fit
# whose error is within max_error.
MAX_TERMS = 9


def wat_data_string(data):
    """Format packed bytes as WAT data-segment strings, one "\\xx..." per f32."""
    return ' '.join(
        '"' + ''.join('\\{:02x}'.format(b) for b in data[i:i + 4]) + '"'
        for i in range(0, len(data), 4)
    )


def main():
    """Fit each entry of `functions` and print its coefficients.

    For every function this prints:
    - the little-endian f32 bytes as hex;
    - the mpmath coefficients, highest degree first;
    - the achieved fit error;
    - the (start, end) byte offsets to pass to $evalpoly and a ready-to-paste
      WAT data line. Tables are assumed to be laid out back to back from
      address 0, in list order.
    """
    import sys

    offset = 0
    for f in functions:
        for n_terms in range(1, MAX_TERMS + 1):
            poly, err = chebyfit(f['function'], f['xrange'], n_terms, error=True)
            if err < f['max_error']:
                break
        else:
            # No fit met the target: say so instead of silently emitting the
            # largest fit as if it were good enough.
            print("warning: {} fit error {} exceeds max_error {} with {} terms".format(
                f['name'], nstr(err, 6), f['max_error'], n_terms), file=sys.stderr)
        data = struct.pack("<{}f".format(len(poly)), *poly)
        print(f['name'], data.hex())
        nprint(poly)
        print('  max fit error:', nstr(err, 12))
        print('  ;; {} polynomial coefficients: $evalpoly (i32.const {}) (i32.const {})'.format(
            f['name'], offset, offset + len(data)))
        print('  ' + wat_data_string(data))
        offset += len(data)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment