Snapshot of SIMD flatten implementation
// Copyright 2025 the Fearless_SIMD Authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Example of fast flattening of cubic Béziers.

// Arguably we should just take a kurbo dep (or do development in
// another crate), but here we keep the snippet self-contained.
use core::f32;
use fearless_simd::{f32x4, f32x8, mask32x4, x86_64::Avx2, Level, Simd, SimdInto};
// This is similar to kurbo's Point, but repr(C).
#[derive(Clone, Copy)]
#[repr(C)]
struct Point {
    x: f64,
    y: f64,
}

#[derive(Clone, Copy, Debug, Default)]
#[repr(C)]
struct Point32 {
    x: f32,
    y: f32,
}
struct Vec2 {
    x: f64,
    y: f64,
}

// Again similar to kurbo, but repr(C).
#[derive(Clone, Copy)]
#[repr(C)]
struct CubicBez {
    p0: Point,
    p1: Point,
    p2: Point,
    p3: Point,
}

const MAX_QUADS: usize = 16;

// Scratch state for flattening one cubic, stored struct-of-arrays so the
// SIMD kernels can load contiguous lanes.
#[derive(Default, Debug)]
struct FlattenCtx {
    // The +4 is to encourage alignment; might be better to be explicit.
    even_pts: [Point32; MAX_QUADS + 4],
    odd_pts: [Point32; MAX_QUADS],
    a0: [f32; MAX_QUADS],
    da: [f32; MAX_QUADS],
    u0: [f32; MAX_QUADS],
    uscale: [f32; MAX_QUADS],
    val: [f32; MAX_QUADS],
    n_quads: usize,
}
impl Point {
    pub const fn new(x: f64, y: f64) -> Self {
        Point { x, y }
    }

    pub fn to_vec2(self) -> Vec2 {
        Vec2 {
            x: self.x,
            y: self.y,
        }
    }

    pub fn to_point32(self) -> Point32 {
        Point32 {
            x: self.x as f32,
            y: self.y as f32,
        }
    }
}

impl Vec2 {
    pub fn to_point(self) -> Point {
        Point::new(self.x, self.y)
    }

    /// The squared length of the vector.
    pub fn hypot2(self) -> f64 {
        self.x * self.x + self.y * self.y
    }
}

impl std::ops::Mul<f64> for Vec2 {
    type Output = Self;
    fn mul(self, rhs: f64) -> Self::Output {
        Vec2 {
            x: self.x * rhs,
            y: self.y * rhs,
        }
    }
}

impl std::ops::Add for Vec2 {
    type Output = Self;
    fn add(self, rhs: Vec2) -> Self::Output {
        Vec2 {
            x: self.x + rhs.x,
            y: self.y + rhs.y,
        }
    }
}

impl std::ops::Sub for Vec2 {
    type Output = Self;
    fn sub(self, rhs: Vec2) -> Self::Output {
        Vec2 {
            x: self.x - rhs.x,
            y: self.y - rhs.y,
        }
    }
}
impl CubicBez {
    #[inline]
    fn eval(&self, t: f64) -> Point {
        let mt = 1.0 - t;
        let v = self.p0.to_vec2() * (mt * mt * mt)
            + (self.p1.to_vec2() * (mt * mt * 3.0)
                + (self.p2.to_vec2() * (mt * 3.0) + self.p3.to_vec2() * t) * t)
                * t;
        v.to_point()
    }
}

// Scalar fallback: sample the cubic at n + 1 evenly spaced "even" points and
// the n interval midpoints ("odd" points). Each (even, odd, even) triple is
// later fitted with one approximating quadratic.
#[inline(never)]
fn eval_cubics(c: &CubicBez, n: usize, result: &mut FlattenCtx) {
    result.n_quads = n;
    let dt = 1.0 / n as f64;
    for i in 0..n {
        let t = i as f64 * dt;
        result.even_pts[i] = c.eval(t).to_point32();
        result.odd_pts[i] = c.eval(t + 0.5 * dt).to_point32();
    }
    result.even_pts[n] = c.p3.to_point32();
}
// Evaluate even/odd sample points with AVX2. Lane layout: with the IOTA2
// step pattern below, the low 128-bit half of each result holds two "even"
// points and the high half holds two "odd" points, so one iteration of the
// loop produces four points.
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn eval_cubics_avx2(avx2: Avx2, c: &CubicBez, n: usize, result: &mut FlattenCtx) {
    result.n_quads = n;
    let p0p1 = avx2.avx._mm256_loadu_pd(c as *const CubicBez as *const f64);
    let p2p3 = avx2
        .avx
        ._mm256_loadu_pd((c as *const CubicBez as *const f64).add(4));
    // Narrow f64 coordinates to f32; each point becomes one 64-bit lane.
    let p0p1_32 = avx2.sse2._mm_castps_pd(avx2.avx._mm256_cvtpd_ps(p0p1));
    let p2p3_32 = avx2.sse2._mm_castps_pd(avx2.avx._mm256_cvtpd_ps(p2p3));
    const IOTA2: [f32; 8] = [0., 0., 2., 2., 1., 1., 3., 3.];
    let iota2: f32x8<Avx2> = IOTA2.simd_into(avx2);
    let dt = 0.5 / n as f32;
    let step = iota2 * dt;
    let p0: f32x8<Avx2> = avx2
        .avx
        ._mm256_castpd_ps(avx2.avx2._mm256_broadcastsd_pd(p0p1_32))
        .simd_into(avx2);
    let p1a = core::arch::x86_64::_mm_extract_epi64::<1>(avx2.sse2._mm_castpd_si128(p0p1_32));
    let p1: f32x8<Avx2> = avx2
        .avx
        ._mm256_castsi256_ps(avx2.avx._mm256_set1_epi64x(p1a))
        .simd_into(avx2);
    let p2: f32x8<Avx2> = avx2
        .avx
        ._mm256_castpd_ps(avx2.avx2._mm256_broadcastsd_pd(p2p3_32))
        .simd_into(avx2);
    let p3a = core::arch::x86_64::_mm_extract_epi64::<1>(avx2.sse2._mm_castpd_si128(p2p3_32));
    let p3: f32x8<Avx2> = avx2
        .avx
        ._mm256_castsi256_ps(avx2.avx._mm256_set1_epi64x(p3a))
        .simd_into(avx2);
    let mut t = step;
    let t_inc = avx2.splat_f32x8(4.0 * dt);
    for i in 0..(n + 1) / 2 {
        // Factored evaluation of the cubic in Bernstein form:
        // mt³·p0 + t·(3·mt·(mt·p1 + t·p2) + t²·p3).
        let mt = 1.0 - t;
        let mt2 = mt * mt;
        let t2 = t * t;
        let z0 = mt2 * mt * p0;
        let z1 = mt * p1;
        let z2 = t.mul_add(p2, z1);
        let z3 = z2 * (mt * 3.0);
        let z4 = t2.mul_add(p3, z3);
        let z = z4.mul_add(t, z0);
        let (zlo, zhi) = avx2.split_f32x8(z);
        avx2.sse
            ._mm_storeu_ps((result.even_pts.as_mut_ptr() as *mut f32).add(i * 4), zlo.into());
        avx2.sse
            ._mm_storeu_ps((result.odd_pts.as_mut_ptr() as *mut f32).add(i * 4), zhi.into());
        t = t + t_inc;
    }
    *(result.even_pts.as_mut_ptr() as *mut i64).add(n) = p3a;
}
fn eval_cubics_simd(level: Level, c: &CubicBez, n: usize, result: &mut FlattenCtx) {
    if let Some(avx2) = level.as_avx2() {
        unsafe {
            eval_cubics_avx2(avx2, c, n, result);
        }
        return;
    }
    eval_cubics(c, n, result);
}
// Approximation of the integral used to parametrize quadratic flattening:
// integral(x) ≈ x / (1 - d + (x²/4 + d⁴)^(1/4)), with d = 0.67.
#[inline(always)]
fn approx_parabola_integral<S: Simd>(x: f32x8<S>) -> f32x8<S> {
    const D: f32 = 0.67;
    let x2 = x * x;
    let t1_sqrt = x2.mul_add(0.25, D.powi(4)).sqrt();
    let t1_fourthroot = t1_sqrt.sqrt();
    let denom = t1_fourthroot + (1.0 - D);
    x / denom
}

// Same as above, but for 128-bit vectors.
#[inline(always)]
fn approx_parabola_integral_x4<S: Simd>(x: f32x4<S>) -> f32x4<S> {
    const D: f32 = 0.67;
    let x2 = x * x;
    let t1_sqrt = x2.mul_add(0.25, D.powi(4)).sqrt();
    let t1_fourthroot = t1_sqrt.sqrt();
    let denom = t1_fourthroot + (1.0 - D);
    x / denom
}

// Approximate inverse of the integral above:
// inv(x) ≈ x * (1 - b + sqrt(x²/4 + b²)), with b = 0.39.
#[inline(always)]
fn approx_parabola_inv_integral<S: Simd>(x: f32x8<S>) -> f32x8<S> {
    const B: f32 = 0.39;
    let x2 = x * x;
    let t1_sqrt = x2.mul_add(0.25, B * B).sqrt();
    let factor = t1_sqrt + (1.0 - B);
    x * factor
}
// TODO: move into library
#[inline(always)]
unsafe fn is_finite_f32x4<S: Simd>(x: f32x4<S>) -> mask32x4<S> {
    // This is ok when simd_lt observes IEEE behavior around NaN.
    // If that's not guaranteed, better to do comparison of bit patterns.
    x.abs().simd_lt(f32::INFINITY)
}
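
// The bit-pattern alternative mentioned above, in scalar form (not in the
// original snapshot; a sketch). A finite f32 has an exponent field that is
// not all ones, i.e. (bits & 0x7fff_ffff) < 0x7f80_0000 as unsigned.
#[allow(dead_code)]
fn is_finite_bits(x: f32) -> bool {
    (x.to_bits() & 0x7fff_ffff) < 0x7f80_0000
}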
// Swizzle [x0, y0, x1, y1, x2, y2, x3, y3]
// to ([x0, x1, x2, x3, x0, x1, x2, x3],
//     [y0, y1, y2, y3, y0, y1, y2, y3]).
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn unzipdup(x: f32x8<Avx2>) -> (f32x8<Avx2>, f32x8<Avx2>) {
    let avx2 = x.simd;
    let a = avx2.avx._mm256_permute_ps::<0xd8>(x.into());
    // a = [x0, x1, y0, y1, x2, x3, y2, y3]
    let a1 = avx2.avx._mm256_castps_si256(a);
    let x = avx2.avx2._mm256_permute4x64_epi64::<0x88>(a1);
    let x1 = avx2.avx._mm256_castsi256_ps(x);
    let y = avx2.avx2._mm256_permute4x64_epi64::<0xdd>(a1);
    let y1 = avx2.avx._mm256_castsi256_ps(y);
    (x1.simd_into(avx2), y1.simd_into(avx2))
}
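
// Scalar model of `unzipdup` (not in the original snapshot; a sketch that
// documents the lane semantics and can serve as a test oracle).
#[allow(dead_code)]
fn unzipdup_scalar(v: [f32; 8]) -> ([f32; 8], [f32; 8]) {
    let [x0, y0, x1, y1, x2, y2, x3, y3] = v;
    (
        [x0, x1, x2, x3, x0, x1, x2, x3],
        [y0, y1, y2, y3, y0, y1, y2, y3],
    )
}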
// For each quadratic, recover the control point from the three samples and
// estimate how many line segments it needs (the `val` metric of the
// parabola-integral flattening scheme). Processes 4 quadratics per iteration.
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn estimate_subdiv_avx2(avx2: Avx2, sqrt_tol: f32, ctx: &mut FlattenCtx) {
    let n = ctx.n_quads;
    for i in 0..(n + 3) / 4 {
        let p0: f32x8<Avx2> = avx2
            .avx
            ._mm256_loadu_ps((ctx.even_pts.as_ptr() as *const f32).add(i * 8))
            .simd_into(avx2);
        let p_onehalf: f32x8<Avx2> = avx2
            .avx
            ._mm256_loadu_ps((ctx.odd_pts.as_ptr() as *const f32).add(i * 8))
            .simd_into(avx2);
        let p2: f32x8<Avx2> = avx2
            .avx
            ._mm256_loadu_ps((ctx.even_pts.as_ptr() as *const f32).add(i * 8 + 2))
            .simd_into(avx2);
        // Control point from the midpoint sample: p1 = 2 * p(1/2) - (p0 + p2) / 2.
        let x = p0 * -0.5;
        let x1 = p_onehalf.mul_add(2.0, x);
        let p1 = p2.mul_add(-0.5, x1);
        avx2.avx
            ._mm256_storeu_ps((ctx.odd_pts.as_mut_ptr() as *mut f32).add(i * 8), p1.into());
        let d01 = p1 - p0;
        let d12 = p2 - p1;
        let (d01x, d01y) = unzipdup(d01);
        let (d12x, d12y) = unzipdup(d12);
        let ddx = d01x - d12x;
        let ddy = d01y - d12y;
        let d02x = d01x + d12x;
        let d02y = d01y + d12y;
        // TODO: wire up mul_sub (vfnmadd)
        let cross = d02x * ddy - d02y * ddx;
        // TODO: these will be smoother when impl'ed
        let d01d12x = avx2.combine_f32x4(avx2.split_f32x8(d01x).0, avx2.split_f32x8(d12x).0);
        let d01d12y = avx2.combine_f32x4(avx2.split_f32x8(d01y).0, avx2.split_f32x8(d12y).0);
        let x0_x2 = d01d12y.mul_add(ddy, d01d12x * ddx) / cross;
        let ddxlo = avx2.split_f32x8(ddx).0;
        let ddylo = avx2.split_f32x8(ddy).0;
        let dd_hypot = ddylo.mul_add(ddylo, ddxlo * ddxlo).sqrt();
        let (x0, x2) = avx2.split_f32x8(x0_x2);
        let scale_denom = dd_hypot * (x2 - x0);
        let scale = (avx2.split_f32x8(cross).0 / scale_denom).abs();
        let a0_a2 = approx_parabola_integral(x0_x2);
        let (a0, a2) = avx2.split_f32x8(a0_a2);
        let da = a2 - a0;
        let da_abs = da.abs();
        let sqrt_scale = scale.sqrt();
        // Sign bit is set where x0 and x2 have opposite signs (the cusp case).
        let mask = avx2.sse._mm_xor_ps(x0.into(), x2.into());
        // Note: subsequent blend will use sign bit only, no need to >= 0.
        let noncusp = da_abs * sqrt_scale;
        let xmin = sqrt_tol / sqrt_scale;
        let approxint = approx_parabola_integral_x4(xmin);
        let cusp = (da_abs * sqrt_tol) / approxint;
        // question: if we did a >= 0 comparison and select, would llvm optimize
        // that away on sse?
        let val_raw = avx2.sse4_1._mm_blendv_ps(noncusp.into(), cusp.into(), mask);
        // Zero out non-finite lanes (degenerate quads contribute no segments).
        let val_is_finite = is_finite_f32x4(val_raw.simd_into(avx2));
        let val = avx2.sse._mm_and_ps(val_raw, avx2.sse2._mm_castsi128_ps(val_is_finite.into()));
        let u0_u2 = approx_parabola_inv_integral(a0_a2);
        let (u0, u2) = avx2.split_f32x8(u0_u2);
        let uscale = 1.0 / (u2 - u0);
        avx2.sse._mm_storeu_ps(ctx.a0.as_mut_ptr().add(i * 4), a0.into());
        avx2.sse._mm_storeu_ps(ctx.da.as_mut_ptr().add(i * 4), da.into());
        // TODO: should store -u0 * uscale
        avx2.sse._mm_storeu_ps(ctx.u0.as_mut_ptr().add(i * 4), u0.into());
        avx2.sse._mm_storeu_ps(ctx.uscale.as_mut_ptr().add(i * 4), uscale.into());
        // question: should we mask on < n?
        avx2.sse._mm_storeu_ps(ctx.val.as_mut_ptr().add(i * 4), val);
    }
}
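
// Scalar model of the per-quadratic math above (not in the original snapshot;
// a minimal sketch for cross-checking one SIMD lane). `p1` here is the control
// point already recovered from the midpoint sample. Returns (a0, da, val);
// u0 and uscale follow from approx_parabola_inv_integral_scalar in the same way.
#[allow(dead_code)]
fn estimate_subdiv_scalar(p0: Point32, p1: Point32, p2: Point32, sqrt_tol: f32) -> (f32, f32, f32) {
    let (d01x, d01y) = (p1.x - p0.x, p1.y - p0.y);
    let (d12x, d12y) = (p2.x - p1.x, p2.y - p1.y);
    let (ddx, ddy) = (d01x - d12x, d01y - d12y);
    let cross = (d01x + d12x) * ddy - (d01y + d12y) * ddx;
    let x0 = (d01x * ddx + d01y * ddy) / cross;
    let x2 = (d12x * ddx + d12y * ddy) / cross;
    let scale = (cross / (ddx.hypot(ddy) * (x2 - x0))).abs();
    let a0 = approx_parabola_integral_scalar(x0);
    let a2 = approx_parabola_integral_scalar(x2);
    let da = a2 - a0;
    let val = if x0.signum() == x2.signum() {
        da.abs() * scale.sqrt()
    } else {
        // Cusp case: x0 and x2 straddle zero.
        let xmin = sqrt_tol / scale.sqrt();
        da.abs() * sqrt_tol / approx_parabola_integral_scalar(xmin)
    };
    // Degenerate quads contribute no segments, matching the SIMD masking.
    let val = if val.is_finite() { val } else { 0.0 };
    (a0, da, val)
}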
// Broadcast a Point32 (one 64-bit x, y pair) to all four point slots.
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn pt_splat(avx2: Avx2, pt: Point32) -> f32x8<Avx2> {
    let p = avx2.avx._mm256_set1_epi64x(core::mem::transmute(pt));
    avx2.avx._mm256_castsi256_ps(p).simd_into(avx2)
}
// Emit `n` line endpoints for quadratic `i`, four points per iteration.
// `x0` and `dx` position the samples within this quad's share of the
// overall subdivision.
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn output_lines(avx2: Avx2, ctx: &FlattenCtx, i: usize, x0: f32, dx: f32, n: usize, out: *mut f32) {
    let p0 = pt_splat(avx2, ctx.even_pts[i]);
    let p1 = pt_splat(avx2, ctx.odd_pts[i]);
    let p2 = pt_splat(avx2, ctx.even_pts[i + 1]);
    const IOTA2: [f32; 8] = [0., 0., 1., 1., 2., 2., 3., 3.];
    let iota2: f32x8<_> = IOTA2.simd_into(avx2);
    let x = iota2.mul_add(dx, x0);
    let da = ctx.da[i];
    let mut a = x.mul_add(da, ctx.a0[i]);
    let a_inc = 4.0 * dx * da;
    let uscale = ctx.uscale[i];
    for j in 0..(n + 3) / 4 {
        // Map back through the inverse integral to get the t parameter,
        // then evaluate the quadratic in Bernstein form.
        let u = approx_parabola_inv_integral(a);
        let t = u.mul_add(uscale, -ctx.u0[i] * uscale);
        let mt = 1.0 - t;
        let z = p0 * (mt * mt);
        let z1 = p1.mul_add(2.0 * t * mt, z);
        let p = p2.mul_add(t * t, z1);
        avx2.avx._mm256_storeu_ps(out.add(j * 8), p.into());
        a = a + a_inc;
    }
}
const TO_QUAD_TOL: f32 = 0.1;

#[target_feature(enable = "avx2,fma")]
unsafe fn flatten_cubic(avx2: Avx2, c: CubicBez, ctx: &mut FlattenCtx, accuracy: f32, result: &mut Vec<Point32>) {
    // Choose the number of quadratics. The approximation error falls off as
    // 1/n³, hence the 1/6 power on the squared-distance ratio.
    let q_accuracy = (accuracy * TO_QUAD_TOL) as f64;
    let max_hypot2 = 432.0 * q_accuracy * q_accuracy;
    let p1x2 = c.p1.to_vec2() * 3.0 - c.p0.to_vec2();
    let p2x2 = c.p2.to_vec2() * 3.0 - c.p3.to_vec2();
    let err = (p2x2 - p1x2).hypot2();
    let mut n_quads = ((err / max_hypot2).powf(1.0 / 6.0).ceil() as usize).max(1);
    n_quads = n_quads.min(MAX_QUADS);
    eval_cubics_avx2(avx2, &c, n_quads, ctx);
    let tol = accuracy * (1.0 - TO_QUAD_TOL);
    let sqrt_tol = tol.sqrt();
    estimate_subdiv_avx2(avx2, sqrt_tol, ctx);
    // This sum is SIMD'able but probably not worth bothering with.
    let sum: f32 = ctx.val[..n_quads].iter().sum();
    let n = ((0.5 * sum / sqrt_tol).ceil() as usize).max(1);
    // Reserve slack: output_lines writes in chunks of 4 points.
    result.reserve(n + 4);
    let step = sum / (n as f32);
    let step_recip = 1.0 / step;
    // val_sum is the running sum of val[..=i].
    let mut val_sum = 0.0;
    let mut last_n = 0;
    let out_ptr = result.as_mut_ptr().add(result.len());
    let mut x0base = 0.0;
    for i in 0..n_quads {
        let val = ctx.val[i];
        val_sum += val;
        let this_n = val_sum * step_recip;
        let this_n_next = 1.0 + this_n.floor();
        let dn = this_n_next as usize - last_n;
        if dn > 0 {
            let dx = step / val;
            let x0 = x0base * dx;
            output_lines(avx2, ctx, i, x0, dx, dn, out_ptr.add(last_n) as *mut f32);
        }
        x0base = this_n_next - this_n;
        last_n = this_n_next as usize;
    }
    *out_ptr.add(n) = ctx.even_pts[n_quads];
    result.set_len(n + 1);
}
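
// Hedged sketch of a safe entry point (not in the original snapshot),
// mirroring the dispatch pattern of `eval_cubics_simd` above: returns false
// when AVX2 is unavailable so the caller can fall back to a scalar path
// (e.g. kurbo's flattening).
#[allow(dead_code)]
fn flatten_cubic_simd(
    level: Level,
    c: CubicBez,
    ctx: &mut FlattenCtx,
    accuracy: f32,
    result: &mut Vec<Point32>,
) -> bool {
    if let Some(avx2) = level.as_avx2() {
        // Safety: `as_avx2` returning Some proves the required target
        // features were detected at runtime, as in `eval_cubics_simd`.
        unsafe { flatten_cubic(avx2, c, ctx, accuracy, result) };
        true
    } else {
        false
    }
}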
fn main() {
    let c = CubicBez {
        p0: Point::new(55.0, 466.0),
        p1: Point::new(350.0, 146.0),
        p2: Point::new(496.0, 537.0),
        p3: Point::new(739.0, 244.0),
    };
    // We default-initialize this, and it might be ok to persist it across
    // calls, but if we're ok with unsafe, then MaybeUninit would be fastest.
    let mut ctx = FlattenCtx::default();
    let level = Level::new();
    let start = std::time::Instant::now();
    let sqrt_tol = 0.25f32.sqrt();
    for _ in 0..1_000_000 {
        eval_cubics_simd(level, &c, 10, &mut ctx);
        if let Some(avx2) = level.as_avx2() {
            unsafe {
                estimate_subdiv_avx2(avx2, sqrt_tol, &mut ctx);
            }
        }
    }
    println!("elapsed: {:?}", start.elapsed());
    // Print the quadratic approximation as an SVG path.
    print!("M{} {}", ctx.even_pts[0].x, ctx.even_pts[0].y);
    for i in 0..ctx.n_quads {
        let p1 = ctx.odd_pts[i];
        let p2 = ctx.even_pts[i + 1];
        print!("Q{} {} {} {}", p1.x, p1.y, p2.x, p2.y);
    }
    println!();
    let mut result = vec![];
    let start2 = std::time::Instant::now();
    for _ in 0..1_000_000 {
        result.clear();
        if let Some(avx2) = level.as_avx2() {
            unsafe {
                flatten_cubic(avx2, c, &mut ctx, 0.25, &mut result);
            }
        }
    }
    println!("elapsed: {:?}", start2.elapsed());
    //println!("{ctx:?}");
    // Print the flattened polyline as an SVG path.
    print!("M{} {}", result[0].x, result[0].y);
    for p in &result[1..] {
        print!("L{} {}", p.x, p.y);
    }
    println!();
}