Created
November 25, 2015 11:09
-
-
Save Astrac/42a33fdc8c844f30de49 to your computer and use it in GitHub Desktop.
A generic, typeclass-driven mk4 (classic fourth-order Runge–Kutta) integrator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package astrac.engine | |
import cats._ | |
import cats.syntax._ | |
import scala.annotation.tailrec | |
object mk4 {

  /** Typeclass for an integrable state `S` whose rate of change is `D`, over a time type `T`.
    *
    * @tparam S state type
    * @tparam D derivative type
    * @tparam T time type
    */
  trait State[S, D, T] {

    /** Adds two states together. */
    def semigroup: Semigroup[S]

    /** The [[Derivate]] instance associated with this state. */
    def derivate: mk4.Derivate[D, T]

    /** Multiplies a state by a scalar; used by [[Integrator]] to interpolate between states. */
    def scale(s: S, factor: Double): S

    /** Turns a derivative applied over the time span `t` into a state increment. */
    def fromDerivate(d: D, t: T): S
  }

  /** Typeclass for a derivative `D` over a time type `T`. */
  trait Derivate[D, T] {

    /** Additive structure of derivatives; `empty` is used as the zero derivative. */
    def monoid: Monoid[D]

    /** Multiplies a derivative by a scalar (used for the RK4 weights). */
    def scale(d: D, factor: Double): D

    /** The [[Time]] instance associated with this derivative. */
    def time: mk4.Time[T]
  }

  /** Typeclass for a time type `T`. */
  trait Time[T] {

    /** Additive structure of time; `empty` is used as the zero duration. */
    def monoid: Monoid[T]

    /** Ordering of time values, used to compare the accumulator against the fixed step. */
    def ordering: Ordering[T]

    /** Half of a duration (RK4 midpoint offset). */
    def half(t: T): T

    /** Additive inverse; subtraction is expressed as `combine(a, negate(b))`. */
    def negate(t: T): T

    /** `num / den` as a plain `Double`; used as the interpolation factor. */
    def ratio(num: T, den: T): Double
  }

  /** Evaluates one RK4 stage: the derivative at the trial point
    * `state + fromDerivate(lastDerivate, dt)` and time `t + dt`.
    *
    * @param state        state at the beginning of the step
    * @param derivate     derivative function `(state, time) => derivative`
    * @param t            time at the beginning of the step
    * @param dt           offset of this stage from `t`
    * @param lastDerivate derivative from the previous stage, applied over `dt` to build the trial state
    * @param dr           not used by this method; kept for source compatibility
    * @return the derivative at the trial point
    */
  def evaluate[S, D, T](
    state: S,
    derivate: (S, T) => D,
    t: T,
    dt: T,
    lastDerivate: D
  )(
    implicit
    st: State[S, D, T],
    dr: Derivate[D, T],
    tm: Time[T]
  ): D = {
    val trialState = st.semigroup.combine(state, st.fromDerivate(lastDerivate, dt))
    val trialTime = tm.monoid.combine(t, dt)
    derivate(trialState, trialTime)
  }

  /** Bundles the [[State]], [[Derivate]] and [[Time]] instances needed to integrate `S`. */
  trait Integrable[S, D, T] {
    implicit def state: State[S, D, T]
    implicit def derivate: Derivate[D, T]
    implicit def time: Time[T]
  }

  object Integrable {

    /** Derives an [[Integrable]] whenever the three underlying instances are implicitly available. */
    implicit def fromSDT[S, D, T](
      implicit
      st: State[S, D, T],
      dr: Derivate[D, T],
      tm: Time[T]
    ): Integrable[S, D, T] = new Integrable[S, D, T] {
      override implicit def state: State[S, D, T] = st
      override implicit def derivate: Derivate[D, T] = dr
      override implicit def time: Time[T] = tm
    }
  }

  /** Advances `initial` by one step of size `dt` using the classic fourth-order Runge–Kutta scheme.
    *
    * Stages are evaluated at offsets 0, dt/2, dt/2 and dt, and combined with weights (1, 2, 2, 1) / 6.
    *
    * @param initial state at time `t`
    * @param fn      derivative function `(state, time) => derivative`
    * @param t       current time
    * @param dt      step size
    * @return the state at time `t + dt`
    */
  def step[S, D, T](initial: S, fn: (S, T) => D, t: T, dt: T)(
    implicit
    integrable: Integrable[S, D, T]
  ): S = {
    // Brings the implicit State/Derivate/Time instances into scope for `evaluate`.
    import integrable._

    val halfDt = time.half(dt)
    val k1 = evaluate(initial, fn, t, time.monoid.empty, derivate.monoid.empty)
    val k2 = evaluate(initial, fn, t, halfDt, k1)
    val k3 = evaluate(initial, fn, t, halfDt, k2)
    val k4 = evaluate(initial, fn, t, dt, k3)

    val weightedSum = derivate.monoid.combineAll(
      List(k1, derivate.scale(derivate.monoid.combine(k2, k3), 2), k4)
    )
    val dxdt = derivate.scale(weightedSum, 1.0 / 6.0)
    state.semigroup.combine(initial, state.fromDerivate(dxdt, dt))
  }

  /** A derivative function bound to its [[Integrable]] instances, ready to take single RK4 steps. */
  class Stepper[S, T, D](fn: (S, T) => D)(implicit int: Integrable[S, D, T]) {

    /** Advances `initial` from time `t` by `dt` (see [[mk4.step]]). */
    def in(t: T, dt: T)(initial: S): S = step(initial, fn, t, dt)
  }

  /** Fixed-step integrator driven by arbitrary sampling times.
    *
    * Elapsed sampling time is accumulated, and whole fixed steps are consumed from the accumulator.
    * The reported state interpolates between the last two simulated states using the leftover
    * fraction of a step.
    */
  class Integrator[S, T, D](fn: (S, T) => D)(implicit int: Integrable[S, D, T]) {
    val stepper: Stepper[S, T, D] = mk4.stepper(fn)

    /** Bookkeeping carried from one sampling time to the next.
      *
      * @param current     latest simulated state
      * @param previous    simulated state one fixed step before `current`, if any step has run
      * @param frameTime   sampling time this record corresponds to
      * @param symTime     simulation time of `current`
      * @param accumulator elapsed time not yet consumed by fixed steps
      */
    case class Step(
      current: S,
      previous: Option[S],
      frameTime: T,
      symTime: T,
      accumulator: T
    )

    /** Runs as many fixed steps of size `dt` as fit in `accumulator`.
      *
      * @param initial     state to start from
      * @param symTime     simulation time of `initial`
      * @param dt          fixed step size; must be strictly greater than `int.time.monoid.empty`
      * @param accumulator time available to be consumed
      * @return `(current, previous, symTime, leftoverAccumulator)`, where `previous` is the state
      *         before the last step taken, or `None` if no step was taken
      * @throws IllegalArgumentException if `dt` is not strictly positive (the accumulator would never drain)
      */
    def consume(
      initial: S,
      symTime: T,
      dt: T,
      accumulator: T
    ): (S, Option[S], T, T) = {
      val time = int.time
      require(
        time.ordering.gt(dt, time.monoid.empty),
        "dt must be strictly positive, otherwise the accumulator never drains"
      )

      @tailrec
      def drain(current: S, previous: Option[S], t: T, acc: T): (S, Option[S], T, T) =
        if (time.ordering.lt(acc, dt)) (current, previous, t, acc)
        else drain(
          stepper.in(t, dt)(current),
          Some(current),
          time.monoid.combine(t, dt),
          time.monoid.combine(acc, time.negate(dt))
        )

      drain(initial, None, symTime, accumulator)
    }

    /** Integrates from `initial` and reports an interpolated state for every sampling time.
      *
      * The first element of the result is `initial` (the seed of the scan). Each later element
      * corresponds to one entry of `samplingTimes`.
      *
      * @param initial       starting state
      * @param startTime     sampling and simulation time of `initial`
      * @param samplingTimes times at which a state is requested
      * @param minDt         fixed simulation step size
      * @param maxDt         upper bound on the time added to the accumulator per sample
      */
    def integrate(
      initial: S,
      startTime: T,
      samplingTimes: Iterable[T],
      minDt: T,
      maxDt: T
    ): Iterable[S] = {
      val time = int.time

      samplingTimes
        .scanLeft(Step(initial, None, startTime, startTime, time.monoid.empty)) {
          (lastStep, newFrameTime) =>
            val frameDelta = time.monoid.combine(newFrameTime, time.negate(lastStep.frameTime))
            val accumulator =
              time.monoid.combine(lastStep.accumulator, time.ordering.min(maxDt, frameDelta))
            val (newState, prevState, newSymTime, newAccumulator) =
              consume(lastStep.current, lastStep.symTime, minDt, accumulator)
            Step(
              newState,
              // If no fixed step ran in this frame, both interpolation endpoints must stay put.
              // Replacing `previous` with `current` would make the output jump ahead to `current`.
              prevState orElse lastStep.previous,
              newFrameTime,
              newSymTime,
              newAccumulator
            )
        }
        .map { frame =>
          frame.previous.fold(frame.current) { previous =>
            val alpha = time.ratio(frame.accumulator, minDt)
            int.state.semigroup.combine(
              int.state.scale(frame.current, alpha),
              int.state.scale(previous, 1 - alpha)
            )
          }
        }
    }
  }

  /** Builds a [[Stepper]] for `fn`. */
  def stepper[S, D, T](fn: (S, T) => D)(implicit int: Integrable[S, D, T]): Stepper[S, T, D] =
    new Stepper(fn)

  /** Builds an [[Integrator]] for `fn`. */
  def integrator[S, D, T](fn: (S, T) => D)(implicit int: Integrable[S, D, T]): Integrator[S, T, D] =
    new Integrator(fn)

  /** Additive monoid on `Double`, shared by the `Double` instances below. */
  val doubleMonoid: Monoid[Double] = new Monoid[Double] {
    val empty: Double = 0.0
    def combine(x: Double, y: Double): Double = x + y
  }

  implicit val doubleTime: Time[Double] = new Time[Double] {
    def monoid: Monoid[Double] = doubleMonoid
    def ordering: Ordering[Double] = implicitly[Ordering[Double]]
    def half(t: Double): Double = t * 0.5
    def negate(t: Double): Double = -t
    def ratio(num: Double, den: Double): Double = num / den
  }

  implicit val doubleDerivate: Derivate[Double, Double] = new Derivate[Double, Double] {
    def monoid: Monoid[Double] = doubleMonoid
    def scale(d: Double, factor: Double): Double = d * factor
    def time: Time[Double] = doubleTime
  }

  implicit val doubleState: State[Double, Double, Double] = new State[Double, Double, Double] {
    def semigroup: Semigroup[Double] = doubleMonoid
    def derivate: Derivate[Double, Double] = doubleDerivate
    def scale(s: Double, factor: Double): Double = s * factor
    def fromDerivate(d: Double, t: Double): Double = d * t
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment