Last active
April 9, 2024 10:42
-
-
Save Rukh/96cbe8b93cdf4976c0c6a367f236fbbd to your computer and use it in GitHub Desktop.
A representation of a unit Bezier curve for `BinaryFloatingPoint` types. This struct allows solving the Bezier curve at a given `x` value, which is useful for applying timing functions directly to numerical calculations.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//
// CAMediaTimingFunction+Unit.swift
//
// Created by Dmitry Gulyagin on 08/04/2024.
// https://gist.github.com/Rukh/96cbe8b93cdf4976c0c6a367f236fbbd
//
import QuartzCore | |
extension CAMediaTimingFunction {
    /// A representation of a unit Bezier curve for `BinaryFloatingPoint` types.
    /// This struct allows solving the Bezier curve at a given `x` value,
    /// which is useful for applying timing functions directly to numerical calculations.
    ///
    /// The curve is a cubic Bezier whose first and last control points are implicitly
    /// (0, 0) and (1, 1); only the two inner control points are supplied.
    ///
    /// The implementation is based on the mathematical concept of a unit Bezier curve,
    /// as detailed in the source: [Apple/UnitBezier.h](https://opensource.apple.com/source/WebCore/WebCore-955.66/platform/graphics/UnitBezier.h)
    ///
    /// Example:
    /// ```swift
    /// let function = CAMediaTimingFunction(name: .easeInEaseOut)
    /// // Store it if you use it multiple times
    /// let unit = function.unit(type: Double.self)
    /// let examples = [0.1, 0.3, 0.5, 0.7, 0.9]
    /// let result = examples.map { unit.solve(x: $0).formatted() }
    /// print(result)
    /// // ["0.019722", "0.187396", "0.5", "0.812604", "0.980278"]
    /// ```
    struct Unit<T: BinaryFloatingPoint> {
        // Coefficients of x(t) = ((ax * t + bx) * t + cx) * t.
        private let ax: T
        private let bx: T
        private let cx: T
        // Coefficients of y(t) = ((ay * t + by) * t + cy) * t.
        private let ay: T
        private let by: T
        private let cy: T

        /// Creates a curve from its two inner control points given as `CGPoint`s.
        ///
        /// Only available when `T` is `CGFloat`'s native storage type.
        /// - Parameters:
        ///   - p1: The first inner control point.
        ///   - p2: The second inner control point.
        init(p1: CGPoint, p2: CGPoint) where T == CGFloat.NativeType {
            // Convert explicitly rather than relying on implicit CGFloat/Double conversion.
            self.init(p1x: T(p1.x), p1y: T(p1.y), p2x: T(p2.x), p2y: T(p2.y))
        }

        /// Calculates the polynomial coefficients; the implicit first and last
        /// control points are (0, 0) and (1, 1).
        /// - Parameters:
        ///   - p1x: `x` of the first inner control point.
        ///   - p1y: `y` of the first inner control point.
        ///   - p2x: `x` of the second inner control point.
        ///   - p2y: `y` of the second inner control point.
        init(p1x: T, p1y: T, p2x: T, p2y: T) {
            cx = 3.0 * p1x
            bx = 3.0 * (p2x - p1x) - cx
            ax = 1.0 - cx - bx
            cy = 3.0 * p1y
            by = 3.0 * (p2y - p1y) - cy
            ay = 1.0 - cy - by
        }

        /// Returns the curve's `y` value at the point whose `x` coordinate is `x`.
        ///
        /// - Parameters:
        ///   - x: The input value, normally in `0...1`.
        ///   - epsilon: Tolerance on `x` used while searching for the curve parameter `t`.
        ///     A value of `0` is allowed; the search then stops once `T`'s precision is exhausted.
        /// - Returns: The `y` value of the curve at the solved parameter.
        func solve(x: T, epsilon: T = 1e-6) -> T {
            return sampleCurveY(t: solveCurveX(x: x, epsilon: epsilon))
        }

        /// Evaluates x(t) using Horner's form.
        private func sampleCurveX(t: T) -> T {
            return ((ax * t + bx) * t + cx) * t
        }

        /// Evaluates y(t) using Horner's form.
        private func sampleCurveY(t: T) -> T {
            return ((ay * t + by) * t + cy) * t
        }

        /// Evaluates dx/dt at `t`.
        private func sampleCurveDerivativeX(t: T) -> T {
            return (3.0 * ax * t + 2.0 * bx) * t + cx
        }

        /// Given an x value, finds the parametric value `t` it came from.
        private func solveCurveX(x: T, epsilon: T) -> T {
            // First try a few iterations of Newton's method -- normally very fast.
            var t = x
            for _ in 0 ..< 8 {
                let error = sampleCurveX(t: t) - x
                if abs(error) < epsilon {
                    return t
                }
                let derivative = sampleCurveDerivativeX(t: t)
                if abs(derivative) < 1e-6 {
                    // Nearly flat: a Newton step would overshoot wildly.
                    break
                }
                t -= error / derivative
            }

            // Fall back to the bisection method for reliability.
            var lower: T = 0.0
            var upper: T = 1.0
            t = x
            if t < lower { return lower }
            if t > upper { return upper }
            while lower < upper {
                let sample = sampleCurveX(t: t)
                if abs(sample - x) < epsilon {
                    return t
                }
                if x > sample {
                    lower = t
                } else {
                    upper = t
                }
                let midpoint = (upper - lower) * 0.5 + lower
                // When the bracket has shrunk to adjacent representable values the
                // midpoint rounds back onto `t`, and every further iteration would
                // repeat the same state. This happens when `epsilon` is below the
                // precision `T` can resolve (e.g. `epsilon == 0`), so stop here.
                if midpoint == t {
                    break
                }
                t = midpoint
            }
            // Failure: return the closest approximation found.
            return t
        }
    }

    /// Returns a `Unit` curve built from this timing function's control points,
    /// with its arithmetic performed in `T`.
    ///
    /// - Parameter type: The floating-point type used by the returned curve.
    ///   Defaults to `Float`, the type `getControlPoint(at:values:)` reports values in.
    /// - Returns: A curve that can be evaluated with `solve(x:epsilon:)`.
    func unit<T: BinaryFloatingPoint>(
        type: T.Type = Float.self
    ) -> Unit<T> {
        // `getControlPoint(at:values:)` writes an (x, y) pair through the pointer.
        // Passing a whole array with `&` yields a pointer valid for every element.
        // The previous `&buffer[0]` only guaranteed access to a single element.
        var p1: [Float] = [0, 0]
        var p2: [Float] = [0, 0]
        // Indices 1 and 2 are the inner control points. Indices 0 and 3 are the
        // fixed endpoints, which `Unit` assumes to be (0, 0) and (1, 1).
        getControlPoint(at: 1, values: &p1)
        getControlPoint(at: 2, values: &p2)
        return Unit(
            p1x: T(p1[0]),
            p1y: T(p1[1]),
            p2x: T(p2[0]),
            p2y: T(p2[1])
        )
    }
}
/// `Unit` holds only immutable stored properties of type `T`, so it may cross
/// concurrency domains whenever `T` itself is `Sendable`.
extension CAMediaTimingFunction.Unit: Sendable where T: Sendable {}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment