Skip to content

Instantly share code, notes, and snippets.

@Astrac
Created November 25, 2015 11:09
Show Gist options
  • Save Astrac/42a33fdc8c844f30de49 to your computer and use it in GitHub Desktop.
Save Astrac/42a33fdc8c844f30de49 to your computer and use it in GitHub Desktop.
A generic, typeclass-driven mk4 (fourth-order Runge–Kutta) integrator
package astrac.engine
import cats._
import cats.syntax._
import scala.annotation.tailrec
object mk4 {
/**
 * Typeclass for a state value `S` that the integrator can advance.
 *
 * @tparam S the state type
 * @tparam D the derivative type paired with `S`
 * @tparam T the time type
 */
trait State[S, D, T] {
  /** Converts a derivative `d` applied over the time span `t` into a state delta. */
  def fromDerivate(d: D, t: T): S

  /** Multiplies a state by a scalar; the integrator uses it to blend two states. */
  def scale(s: S, factor: Double): S

  /** Associative combination of states, used as `state ⊕ delta`. */
  def semigroup: Semigroup[S]

  /** The derivative instance that accompanies this state instance. */
  def derivate: Derivate[D, T]
}
/**
 * Typeclass for the time-derivative `D` of a state, over time type `T`.
 *
 * @tparam D the derivative type
 * @tparam T the time type
 */
trait Derivate[D, T] {
  /** The time instance that accompanies this derivative instance. */
  def time: Time[T]

  /** Multiplies a derivative by a scalar (the RK4 weights 2 and 1/6 go through here). */
  def scale(d: D, factor: Double): D

  /**
   * Additive structure on derivatives; `empty` is used as the starting
   * derivative of the first RK4 stage.
   */
  def monoid: Monoid[D]
}
/**
 * Typeclass for the time type `T` used by the integrator.
 *
 * @tparam T the time type
 */
trait Time[T] {
  /** Half of `t`; used for the RK4 mid-point stages. */
  def half(t: T): T

  /** Additive inverse of `t`; together with `monoid` this provides subtraction. */
  def negate(t: T): T

  /** `num` divided by `den` as a plain `Double`; used for the interpolation factor. */
  def ratio(num: T, den: T): Double

  /** Ordering used to compare accumulated time against the step size. */
  def ordering: Ordering[T]

  /** Additive structure on time; `empty` is used as zero elapsed time. */
  def monoid: Monoid[T]
}
/**
 * A single RK4 stage: projects `state` forward by `lastDerivate` over `dt`
 * and samples the derivative function at that point, at time `t ⊕ dt`.
 *
 * @param state        state at time `t`
 * @param derivate     derivative function `f(state, time)`
 * @param t            time corresponding to `state`
 * @param dt           offset from `t` at which the derivative is sampled
 * @param lastDerivate derivative from the previous stage, used for the projection
 * @return `derivate(state ⊕ fromDerivate(lastDerivate, dt), t ⊕ dt)`
 */
def evaluate[S, D, T](
  state: S,
  derivate: (S, T) => D,
  t: T,
  dt: T,
  lastDerivate: D
)(
  implicit
  st: State[S, D, T],
  dr: Derivate[D, T],
  tm: Time[T]
): D = {
  // `dr` is not referenced below; it stays in the implicit list so callers are unaffected.
  val delta = st.fromDerivate(lastDerivate, dt)
  val projected = st.semigroup.combine(state, delta)
  val sampleAt = tm.monoid.combine(t, dt)
  derivate(projected, sampleAt)
}
/**
 * Bundles the state, derivative and time instances so they can be requested
 * as one implicit. The members are themselves implicit, so
 * `import integrable._` puts all three into implicit scope.
 */
trait Integrable[S, D, T] {
  implicit def time: Time[T]
  implicit def derivate: Derivate[D, T]
  implicit def state: State[S, D, T]
}
/** Companion providing the default derivation of [[Integrable]]. */
object Integrable {
  /**
   * Assembles an [[Integrable]] from individually available state,
   * derivative and time instances.
   */
  implicit def fromSDT[S, D, T](
    implicit st: State[S, D, T],
    dr: Derivate[D, T],
    tm: Time[T]
  ): Integrable[S, D, T] =
    new Integrable[S, D, T] {
      override implicit def state: State[S, D, T] = st
      override implicit def derivate: Derivate[D, T] = dr
      override implicit def time: Time[T] = tm
    }
}
/**
 * Advances `initial` by one classic fourth-order Runge-Kutta step.
 *
 * The four stages sample the derivative at `t`, `t + dt/2` (twice) and
 * `t + dt`. They are combined as `(k1 + 2 * (k2 + k3) + k4) / 6`, and that
 * slope is applied over `dt`.
 *
 * @param initial state at time `t`
 * @param fn      derivative function `f(state, time)`
 * @param t       current time
 * @param dt      step size
 * @return the state after one step of size `dt`
 */
def step[S, D, T](initial: S, fn: (S, T) => D, t: T, dt: T)(
  implicit
  integrable: Integrable[S, D, T]
): S = {
  import integrable._

  val k1 = evaluate(initial, fn, t, time.monoid.empty, derivate.monoid.empty)
  val k2 = evaluate(initial, fn, t, time.half(dt), k1)
  val k3 = evaluate(initial, fn, t, time.half(dt), k2)
  val k4 = evaluate(initial, fn, t, dt, k3)

  // The two mid-point stages carry weight 2; the whole sum is then divided by 6.
  val midTerms = derivate.scale(derivate.monoid.combine(k2, k3), 2)
  val weightedSum = derivate.monoid.combineAll(List(k1, midTerms, k4))
  val slope = derivate.scale(weightedSum, 1.0 / 6.0)

  state.semigroup.combine(initial, state.fromDerivate(slope, dt))
}
/**
 * Captures a derivative function so single RK4 steps can be taken with
 * `stepper.in(t, dt)(state)`.
 */
class Stepper[S, T, D](fn: (S, T) => D)(implicit int: Integrable[S, D, T]) {
  /** Advances `initial` from time `t` by one [[step]] of size `dt`. */
  def in(t: T, dt: T)(initial: S): S =
    step(initial, fn, t, dt)(int)
}
/**
 * Fixed-timestep driver built on top of [[Stepper]].
 *
 * Simulation time always advances in whole steps of `minDt`. Frame time that
 * does not fill a whole step is carried in an accumulator. The emitted state
 * blends the last two simulated states with `alpha = accumulator / minDt`,
 * following the accumulator/interpolation scheme from "Fix Your Timestep!".
 *
 * @param fn derivative function `f(state, time)`
 */
class Integrator[S, T, D](fn: (S, T) => D)(implicit int: Integrable[S, D, T]) {
  val stepper: Stepper[S, T, D] = mk4.stepper(fn)

  /**
   * Snapshot carried from one sampling time to the next.
   *
   * @param current     latest simulated state
   * @param previous    state one simulation step before `current`, once a step has been taken
   * @param frameTime   sampling time this snapshot corresponds to
   * @param symTime     simulation time of `current`
   * @param accumulator frame time not yet consumed by simulation steps
   */
  case class Step(
    current: S,
    previous: Option[S],
    frameTime: T,
    symTime: T,
    accumulator: T
  )

  /**
   * Takes as many steps of size `dt` as fit into `accumulator`.
   *
   * @return a tuple of four values:
   *         - the state after the last step;
   *         - the state before the last step, or `None` if no step was taken;
   *         - the simulation time after the steps;
   *         - the leftover accumulator.
   * @throws IllegalArgumentException if `dt` is not strictly positive, since such
   *         a step would never drain the accumulator and the loop would not end
   */
  def consume(
    initial: S,
    symTime: T,
    dt: T,
    accumulator: T
  ): (S, Option[S], T, T) = {
    val time = int.time
    require(
      time.ordering.gt(dt, time.monoid.empty),
      "dt must be strictly positive"
    )

    @tailrec
    def consumeAcc(
      current: S,
      previous: Option[S],
      t: T,
      accumulator: T
    ): (S, Option[S], T, T) =
      if (time.ordering.lt(accumulator, dt))
        (current, previous, t, accumulator)
      else consumeAcc(
        stepper.in(t, dt)(current),
        Some(current),
        time.monoid.combine(t, dt),
        time.monoid.combine(accumulator, time.negate(dt))
      )

    consumeAcc(initial, None, symTime, accumulator)
  }

  /**
   * Runs the simulation across a sequence of sampling (frame) times.
   *
   * For each sampling time, the elapsed time since the previous one (capped at
   * `maxDt`) is added to the accumulator. Whole steps of `minDt` are then
   * consumed, and the emitted state is
   * `current * alpha + previous * (1 - alpha)` with `alpha = accumulator / minDt`.
   *
   * The result begins with `initial` (the seed of the scan), followed by one
   * state per element of `samplingTimes`.
   *
   * @param initial       state at `startTime`
   * @param startTime     initial frame and simulation time
   * @param samplingTimes frame times at which a state is wanted
   * @param minDt         fixed simulation step; must be strictly positive
   * @param maxDt         cap on the frame time added to the accumulator per sample
   */
  def integrate(
    initial: S,
    startTime: T,
    samplingTimes: Iterable[T],
    minDt: T,
    maxDt: T
  ): Iterable[S] =
    samplingTimes
      .scanLeft(
        Step(initial, None, startTime, startTime, int.time.monoid.empty)
      )(advance(_, _, minDt, maxDt))
      .map(interpolate(_, minDt))

  /** Folds one sampling time into the running [[Step]]. */
  private def advance(lastStep: Step, newFrameTime: T, minDt: T, maxDt: T): Step = {
    val time = int.time
    val frameDelta = time.monoid.combine(newFrameTime, time.negate(lastStep.frameTime))
    val accumulator = time.monoid.combine(
      lastStep.accumulator,
      time.ordering.min(maxDt, frameDelta)
    )
    val (newState, prevState, newSymTime, newAccumulator) =
      consume(lastStep.current, lastStep.symTime, minDt, accumulator)
    Step(
      newState,
      // If no step was taken this frame, keep the previous *simulation* state so the
      // blend keeps advancing with `alpha`. Falling back to `lastStep.current` would
      // make previous == current and snap the output to `current`, causing stutter.
      prevState orElse lastStep.previous,
      newFrameTime,
      newSymTime,
      newAccumulator
    )
  }

  /** Blends the last two simulated states according to the leftover accumulator. */
  private def interpolate(snapshot: Step, minDt: T): S =
    snapshot.previous.fold(snapshot.current) { previous =>
      val alpha = int.time.ratio(snapshot.accumulator, minDt)
      int.state.semigroup.combine(
        int.state.scale(snapshot.current, alpha),
        int.state.scale(previous, 1 - alpha)
      )
    }
}
/** Builds a [[Stepper]] for the derivative function `fn`. */
def stepper[S, D, T](fn: (S, T) => D)(implicit int: Integrable[S, D, T]): Stepper[S, T, D] =
  new Stepper[S, T, D](fn)(int)

/** Builds an [[Integrator]] for the derivative function `fn`. */
def integrator[S, D, T](fn: (S, T) => D)(implicit int: Integrable[S, D, T]): Integrator[S, T, D] =
  new Integrator[S, T, D](fn)(int)
/** Additive monoid on `Double`: identity `0.0`, operation `+`. */
val doubleMonoid: Monoid[Double] = new Monoid[Double] {
  override val empty: Double = 0.0
  override def combine(x: Double, y: Double): Double = x + y
}
/**
 * [[Time]] instance for `Double` time values, using plain arithmetic.
 *
 * The result type is stated explicitly. Scala 3 requires it for implicit
 * definitions, and Scala 2 only considers an untyped implicit at use sites
 * that come after its definition in the same file.
 */
implicit val doubleTime: Time[Double] = new Time[Double] {
  def monoid: Monoid[Double] = doubleMonoid
  def ordering: Ordering[Double] = implicitly[Ordering[Double]]
  def half(t: Double): Double = t * 0.5
  def negate(t: Double): Double = -t
  def ratio(num: Double, den: Double): Double = num / den
}
/**
 * [[Derivate]] instance for `Double` derivatives over `Double` time.
 *
 * The result type is stated explicitly. Scala 3 requires it for implicit
 * definitions, and Scala 2 only considers an untyped implicit at use sites
 * that come after its definition in the same file.
 */
implicit val doubleDerivate: Derivate[Double, Double] = new Derivate[Double, Double] {
  def monoid: Monoid[Double] = doubleMonoid
  def scale(d: Double, factor: Double): Double = d * factor
  def time: Time[Double] = doubleTime
}
/**
 * [[State]] instance for scalar `Double` states. A derivative `d` applied over
 * time `t` contributes the delta `d * t`.
 *
 * The result type is stated explicitly. Scala 3 requires it for implicit
 * definitions, and Scala 2 only considers an untyped implicit at use sites
 * that come after its definition in the same file.
 */
implicit val doubleState: State[Double, Double, Double] = new State[Double, Double, Double] {
  def semigroup: Semigroup[Double] = doubleMonoid
  def derivate: Derivate[Double, Double] = doubleDerivate
  def scale(s: Double, factor: Double): Double = s * factor
  def fromDerivate(d: Double, t: Double): Double = d * t
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment