Created
March 13, 2025 07:59
-
-
Save philnguyen/d10d044bd3e3449e9b9fef20ea592521 to your computer and use it in GitHub Desktop.
Rudimentary example of modeling units as types in Lean
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Lean.Data.PersistentHashMap | |
open Lean | |
---------------------------------------------------------------------------------------------------- | |
-- Units as types (sloppily using strings to name the base units)
---------------------------------------------------------------------------------------------------- | |
/-- An integer that is known not to be zero (used below as a unit exponent). -/
abbrev NonZero := Subtype fun n : Int => n ≠ 0
/-- Render a natural number using Unicode superscript digits, e.g. `12 ↦ "¹²"`. -/
private def Nat.supscript (n : Nat) : String :=
  n.toSuperDigits |>.asString

/-- Render an integer using Unicode superscript digits; negatives get a leading `⁻`. -/
private def Int.supscript (n : Int) : String :=
  let digits := n.natAbs.supscript
  if n < 0 then "⁻" ++ digits else digits
/-- A unit as a map from base-unit name to its non-zero integer exponent,
e.g. meter per second is `{"meter" ↦ 1, "second" ↦ -1}`. -/
abbrev Components := Lean.PersistentHashMap String NonZero
/-- Render each base unit followed by its exponent in superscript, as a list,
e.g. `[second⁻¹, meter¹]`. Entries are not sorted: each one is prepended while
folding over the hash map, so the order follows the map's traversal order. -/
private def Components.toString (c : Components) : String :=
  let rendered : List String :=
    c.foldl (fun acc name ⟨exp, _⟩ => s!"{name}{exp.supscript}" :: acc) []
  rendered.toString

instance : ToString Components := ⟨Components.toString⟩
/-- Combine the exponents of two units key by key with `op`.

For every base unit present in `u₁` or `u₂`, the result's exponent is `op d₁ d₂`,
where a base unit missing from one side contributes exponent `0`. An entry whose
combined exponent is `0` is removed, so a `Components` never stores a zero power
(e.g. `meter¹ / meter¹` yields the empty map). -/
def Components.pointwise (op : Int → Int → Int) (u₁ u₂ : Components) : Components :=
  -- Exponent of `key` in `u`, treating an absent key as `0`.
  let exponent (u : Components) (key : String) : Int :=
    match u.find? key with
    | some ⟨d, _⟩ => d
    | none => 0
  -- Store `d` under `key`, or drop the key entirely when the exponent cancels to `0`.
  let store (u : Components) (key : String) (d : Int) : Components :=
    if h : d = 0 then u.erase key else u.insert key ⟨d, h⟩
  -- Keys that occur only in `u₁` must also go through `op` (as `op d₁ 0`);
  -- leaving them untouched is only correct for ops with `op x 0 = x`, like `+`/`-`.
  let leftOnly := u₁.foldl
    (fun acc key ⟨d₁, _⟩ => if u₂.contains key then acc else store acc key (op d₁ 0))
    u₁
  -- Every key of `u₂`: combine with `u₁`'s exponent (or `0` if `u₁` lacks it).
  u₂.foldl (fun acc key ⟨d₂, _⟩ => store acc key (op (exponent u₁ key) d₂)) leftOnly
/-- Multiplying units adds their exponents. -/
instance : Mul Components := ⟨Components.pointwise Int.add⟩
/-- Dividing units subtracts the divisor's exponents. -/
instance : Div Components := ⟨Components.pointwise Int.sub⟩
---------------------------------------------------------------------------------------------------- | |
-- Value with units | |
---------------------------------------------------------------------------------------------------- | |
/-- A floating-point quantity whose unit is tracked in its type via `components`. -/
structure unit (components : Components) where
  mk ::
  /-- The numeric magnitude of the quantity. -/
  amount : Float
/-- Quantities may only be added when their units are the same type index. -/
instance : Add (unit c) := ⟨fun a b => ⟨a.amount + b.amount⟩⟩
/-- Quantities may only be subtracted when their units are the same type index. -/
instance : Sub (unit c) := ⟨fun a b => ⟨a.amount - b.amount⟩⟩
/-- Multiplying quantities multiplies their units (exponents add). -/
instance : HMul (unit c₁) (unit c₂) (unit (c₁ * c₂)) :=
  ⟨fun a b => ⟨a.amount * b.amount⟩⟩
/-- Dividing quantities divides their units (exponents subtract). -/
instance : HDiv (unit c₁) (unit c₂) (unit (c₁ / c₂)) :=
  ⟨fun a b => ⟨a.amount / b.amount⟩⟩
/-- `c¹`: the single base unit `c` raised to the first power. -/
notation:max c "¹" => PersistentHashMap.empty.insert c ⟨1, by decide⟩
/-- `c²`: the single base unit `c` squared. -/
notation:max c "²" => PersistentHashMap.empty.insert c ⟨2, by decide⟩
/-- `c⁻¹`: the reciprocal of the single base unit `c`. -/
notation:max c "⁻¹" => PersistentHashMap.empty.insert c ⟨-1, by decide⟩
/-- `c⁻²`: the reciprocal of the single base unit `c`, squared. -/
notation:max c "⁻²" => PersistentHashMap.empty.insert c ⟨-2, by decide⟩
/-- Show the magnitude followed by the rendered unit, e.g. `3.500000 [second⁻¹, meter¹]`. -/
instance : ToString (unit c) :=
  ⟨fun q => s!"{q.amount} {Components.toString c}"⟩
/-- A dimensionless quantity: a `unit` whose component map is empty. -/
abbrev scalar : Type := unit .empty
---------------------------------------------------------------------------------------------------- | |
-- Examples | |
---------------------------------------------------------------------------------------------------- | |
/-- A distance of 42 meters. -/
def distance := (⟨42.0⟩ : unit "meter"¹)
/-- A duration of 12 seconds. -/
def time := (⟨12⟩ : unit "second"¹)
/-- Dividing quantities divides their units, giving `meter¹ / second¹`. -/
def velocity := distance / time
-- 42.0 / 12 = 3.5; pin it so the example cannot silently drift from its comment.
#guard velocity.amount == 3.5
-- Prints e.g. `3.500000 [second⁻¹, meter¹]`; the component order follows hash-map traversal.
#eval velocity
-- #eval distance + time -- Type error: cannot add values of incompatible units
-- def velocitySquared : unit "second"¹ := velocity * velocity -- Type error: mismatch with annotation
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment