Skip to content

Instantly share code, notes, and snippets.

@philnguyen
Created March 13, 2025 07:59
Show Gist options
  • Save philnguyen/d10d044bd3e3449e9b9fef20ea592521 to your computer and use it in GitHub Desktop.
Save philnguyen/d10d044bd3e3449e9b9fef20ea592521 to your computer and use it in GitHub Desktop.
Rudimentary example of modeling units as types in Lean
import Lean.Data.PersistentHashMap
open Lean
----------------------------------------------------------------------------------------------------
-- Units as types, using plain strings (sloppily) as the base units
----------------------------------------------------------------------------------------------------
/-- A non-zero integer, used as the exponent of a base unit. The proof component
lets a unit map never hold a zero exponent. -/
abbrev NonZero : Type := Subtype fun (n : Int) => n ≠ 0
/-- Render a natural number using Unicode superscript digits (e.g. `42` ↦ `"⁴²"`). -/
private def Nat.supscript (n : Nat) : String :=
  n.toSuperscriptString
/-- Render an integer using Unicode superscripts. A negative value gets a
superscript minus `⁻` in front of the superscript digits of its magnitude. -/
private def Int.supscript (n : Int) : String :=
  let magnitude := n.natAbs.supscript
  if n < 0 then "⁻" ++ magnitude else magnitude
/-- A unit map from base-unit names to their non-zero exponents
(e.g. `meter ↦ 1, second ↦ -1`). -/
abbrev Components : Type := Lean.PersistentHashMap String NonZero
/-- Render a unit map as a bracketed list of entries. Each entry is the base-unit
name followed by its superscript exponent. Entries are prepended while folding,
so they appear in the reverse of the map's iteration order. -/
private def Components.toString (c : Components) : String :=
  let render (acc : List String) (base : String) (exp : NonZero) : List String :=
    s!"{base}{exp.val.supscript}" :: acc
  (c.foldl render []).toString
/-- Display unit maps via `Components.toString`. -/
instance : ToString Components := ⟨Components.toString⟩
/-- Combine two unit maps exponent-by-exponent with `op`.

Starts from `u₁` and folds over every entry of `u₂`. For each key `c` of `u₂`,
the new exponent is `op d₀ d₁`, where `d₀` is the accumulated exponent for `c`
(or `0` if absent) and `d₁` is `u₂`'s exponent. A result of `0` erases the key,
so the map never stores a zero exponent, matching the `NonZero` invariant.

NOTE(review): keys present only in `u₁` are copied through untouched, so `op` is
assumed to satisfy `op d 0 = d`. That holds for `Int.add` and `Int.sub`, the
operators this file's `Mul`/`Div` instances pass in. It does not hold for an
arbitrary `op`. -/
def Components.pointwise (op : Int → Int → Int) (u₁ u₂ : Components) : Components :=
u₂.foldl (λ u c d =>
-- Exponent accumulated so far for `c`; a missing key counts as exponent 0.
let d₀ := match u.find? c with
| .some ⟨d₀, _⟩ => d₀
| .none => 0
let ⟨d₁, _⟩ := d
let d' := op d₀ d₁
-- Naming the decision (`canceled`) puts `¬ d' = 0` in context in the `else`
-- branch, where `assumption` uses it as the `NonZero` proof.
if canceled: d' = 0 then u.erase c else u.insert c ⟨d', by assumption⟩
)
u₁
/-- Multiplying units adds exponents key-by-key (e.g. `m¹ * m¹ = m²`). -/
instance : Mul Components := ⟨Components.pointwise Int.add⟩
/-- Dividing units subtracts exponents key-by-key (e.g. `m¹ / s¹ = m¹ s⁻¹`). -/
instance : Div Components := ⟨Components.pointwise Int.sub⟩
----------------------------------------------------------------------------------------------------
-- Value with units
----------------------------------------------------------------------------------------------------
/-- A quantity whose unit map `components` lives only in its type. No field uses
the parameter, which is what lets the arithmetic instances reject mismatched
units at elaboration time. -/
structure unit (components : Components) where
  mk ::
  /-- The numeric magnitude of the quantity. -/
  amount : Float
/-- Addition is only defined between quantities carrying the same unit map `c`. -/
instance : Add (unit c) := ⟨fun a b => ⟨a.amount + b.amount⟩⟩
/-- Subtraction is only defined between quantities carrying the same unit map `c`. -/
instance : Sub (unit c) := ⟨fun a b => ⟨a.amount - b.amount⟩⟩
/-- Multiplying quantities multiplies magnitudes; the result type combines the
unit maps with `Components` multiplication. -/
instance : HMul (unit c₁) (unit c₂) (unit (c₁ * c₂)) := ⟨fun a b => ⟨a.amount * b.amount⟩⟩
/-- Dividing quantities divides magnitudes; the result type combines the unit
maps with `Components` division. -/
instance : HDiv (unit c₁) (unit c₂) (unit (c₁ / c₂)) := ⟨fun a b => ⟨a.amount / b.amount⟩⟩
-- Postfix exponent notations: `c¹`, `c²`, `c⁻¹` and `c⁻²` each build a one-entry
-- `PersistentHashMap` mapping base unit `c` to that exponent, for use where a
-- `Components` value is expected (e.g. `unit "meter"¹`). In each one, `by simp`
-- discharges the `≠ 0` side condition for the literal exponent.
-- Only exponents ±1 and ±2 have notation; build other exponents with `insert` directly.
notation:max c "¹" => PersistentHashMap.empty.insert c ⟨1, by simp⟩
notation:max c "²" => PersistentHashMap.empty.insert c ⟨2, by simp⟩
notation:max c "⁻¹" => PersistentHashMap.empty.insert c ⟨-1, by simp⟩
notation:max c "⁻²" => PersistentHashMap.empty.insert c ⟨-2, by simp⟩
/-- Show a quantity as its magnitude, a space, then its rendered unit map. -/
instance : ToString (unit c) :=
  ⟨fun q => s!"{q.amount} {Components.toString c}"⟩
/-- A dimensionless quantity: a `unit` whose unit map is empty. -/
abbrev scalar : Type := unit Lean.PersistentHashMap.empty
----------------------------------------------------------------------------------------------------
-- Examples
----------------------------------------------------------------------------------------------------
-- A distance and a time, each tagged with a single base unit at exponent 1.
def distance := (⟨42.0⟩ : unit "meter"¹)
def time := (⟨12⟩ : unit "second"¹)
-- Division subtracts exponents, so the result type carries meter¹ and second⁻¹.
def velocity := distance / time
#eval velocity -- magnitude 42.0 / 12 = 3.5, units meter¹ second⁻¹ (print order depends on the hash map's layout)
-- #eval distance + time -- Type error: cannot add values of incompatible units
-- def velocitySquared : unit "second"¹ := velocity * velocity -- Type error: mismatch with annotation
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment