Last active
November 17, 2020 02:26
-
-
Save lexi-lambda/5bec3f33b1db4269fc129242b53b5f43 to your computer and use it in GitHub Desktop.
Inductive n-tensor representation proof of concept in Agda
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Matrix where | |
open import Data.Fin as Fin using (Fin; zero; suc) | |
open import Data.List as List using (List; []; _∷_; product) | |
open import Data.Nat as Nat | |
open import Data.Nat.Properties as Nat | |
open import Data.Product as Prod | |
open import Data.Vec as Vec using (Vec; []; _∷_) | |
open import Level using (Level) | |
open import Relation.Binary.PropositionalEquality using (_≡_; refl) | |
-- ----------------------------------------------------------------------------- | |
-- ranges | |
module Range where
  -- A contiguous run of natural-number indices, stored as its first
  -- index (`offset`) together with how many indices it covers (`length`).
  record Range : Set where
    field
      offset : ℕ  -- first index covered
      length : ℕ  -- number of indices covered
  open Range public

  -- Half-open range: from i (inclusive) up to j (exclusive).
  -- `∸` is truncated subtraction, so j ≤ i gives an empty range that
  -- still starts at i.
  [_,_⟩ : ℕ -> ℕ -> Range
  [ i , j ⟩ = record { length = j ∸ i ; offset = i }

  -- Closed range: from i to j, both inclusive (the same value as
  -- [ i , suc j ⟩).
  [_,_] : ℕ -> ℕ -> Range
  [ i , j ] = record { offset = i ; length = suc j ∸ i }

  -- One past the last covered index, i.e. offset + length.
  bound₎ : Range -> ℕ
  bound₎ rng = offset rng + length rng
module FinRange where
  open Range using (Range)

  -- A `Range` packaged with evidence that it fits in an index space of
  -- size n: the index is the range's bound plus some leftover `slack`.
  -- Every index the range covers is therefore strictly below n.
  data FinRange : ℕ -> Set where
    fromRange : (r : Range) -> (s : ℕ) -> FinRange (Range.bound₎ r + s)

  -- The underlying plain range, forgetting the size evidence.
  toRange : ∀ {n} -> FinRange n -> Range
  toRange (fromRange rng _) = rng

  -- First covered index.
  offset : ∀ {n} -> FinRange n -> ℕ
  offset fr = Range.offset (toRange fr)

  -- Number of covered indices.
  length : ∀ {n} -> FinRange n -> ℕ
  length fr = Range.length (toRange fr)

  -- How many indices of the size-n space lie past the end of the range.
  slack : ∀ {n} -> FinRange n -> ℕ
  slack (fromRange _ extra) = extra

  -- Half-open [ i , j ⟩. The slack is implicit so callers normally omit
  -- it and let Agda solve it from the expected size index.
  [_,_⟩ : ∀ i j {s} -> FinRange (i + (j ∸ i) + s)
  [ i , j ⟩ {s} = fromRange (Range.[ i , j ⟩) s

  -- Closed [ i , j ], with the slack likewise left implicit.
  [_,_] : ∀ i j {s} -> FinRange (i + (suc j ∸ i) + s)
  [ i , j ] {s} = fromRange (Range.[ i , j ]) s

  -- One past the last covered index.
  bound₎ : ∀ {n} -> FinRange n -> ℕ
  bound₎ fr = Range.bound₎ (toRange fr)

  -- The size index always splits as offset + length + slack; this holds
  -- by refl as soon as the constructor is exposed.
  decomp : ∀ {n} -> (r : FinRange n) -> n ≡ offset r + length r + slack r
  decomp (fromRange _ _) = refl
open FinRange using (FinRange; [_,_⟩; [_,_]) | |
-- The elements of v covered by r, in their original order.
-- Matching `decomp r` against refl refines the length n to
-- o + len + s. That expression is literally (o + len) + s, so `take`
-- applies directly and no associativity rewrite is needed: keep the
-- first o + len elements, then drop the first o of them.
sliceVec : {l : Level} -> {A : Set l} -> {n : _}
         -> Vec A n -> (r : FinRange n) -> Vec A (FinRange.length r)
sliceVec v r
  with FinRange.offset r | FinRange.length r | FinRange.slack r | FinRange.decomp r
... | o | len | s | refl = Vec.drop o (Vec.take (o + len) v)
-- ----------------------------------------------------------------------------- | |
-- basic definitions | |
-- A multi-dimensional array of elements of A, indexed by its shape:
-- the list of dimension sizes, outermost first. Mat A [] holds exactly
-- one element; Mat A (n ∷ ns) holds n sub-arrays, each of shape ns.
data Mat (A : Set) : List ℕ -> Set where
  scalar : A -> Mat A []
  vector : {n : ℕ} {ns : List ℕ} -> Vec (Mat A ns) n -> Mat A (n ∷ ns)
-- Flatten an array into its elements in row-major order, so the last
-- dimension varies fastest. The length is the product of the
-- dimensions; a scalar (empty shape) yields a one-element vector.
toVec : ∀ {A ns} -> Mat A ns -> Vec A (product ns)
toVec (scalar x) = x ∷ []
toVec (vector rows) = Vec.concat (Vec.map toVec rows)
-- Rebuild an array of shape ns from its elements, using the same
-- row-major layout that toVec produces. For shape n ∷ ns, the input is
-- split by `Vec.group` into n consecutive chunks of length product ns,
-- and each chunk becomes one sub-array. Only the chunks are needed;
-- the accompanying concat equation is discarded.
fromVec : ∀ {A ns} -> Vec A (product ns) -> Mat A ns
fromVec {ns = []} (x ∷ []) = scalar x
fromVec {ns = n ∷ ns} xs =
  vector (Vec.map fromVec (proj₁ (Vec.group n (product ns) xs)))
-- View an array under a different shape with the same element count,
-- keeping the row-major element order. The count equality is an
-- instance argument, so callers may leave it to instance search.
-- It is used to transport the flattened vector to the target length.
reshape : ∀ {A ns ms} -> {{product ns ≡ product ms}} -> Mat A ns -> Mat A ms
reshape {{eq}} m = fromVec (subst (Vec _) eq (toVec m))
  where open Relation.Binary.PropositionalEquality using (subst)
-- ----------------------------------------------------------------------------- | |
-- slicing | |
-- A selection that turns an array of shape xs into one of shape ys.
-- It is consumed one dimension at a time, outermost first:
--   all    keeps every remaining dimension as is;
--   index  picks one position of the current dimension, removing it;
--   range  keeps a contiguous run of the current dimension.
data Slice : List ℕ -> List ℕ -> Set where
  all   : {xs : List ℕ} -> Slice xs xs
  index : {n : ℕ} {xs ys : List ℕ} -> Fin n -> Slice xs ys -> Slice (n ∷ xs) ys
  range : {n : ℕ} {xs ys : List ℕ}
        -> (r : FinRange n) -> Slice xs ys -> Slice (n ∷ xs) (FinRange.length r ∷ ys)
-- Apply a Slice to an array. `index` descends into one sub-array, while
-- `range` keeps the covered sub-arrays and slices each of them further.
slice : ∀ {A xs ys} -> Slice xs ys -> Mat A xs -> Mat A ys
slice all m = m
slice (index i rest) (vector subs) = slice rest (Vec.lookup subs i)
slice (range r rest) (vector subs) = vector (Vec.map (slice rest) (sliceVec subs r))
-- ----------------------------------------------------------------------------- | |
-- examples | |
open import Data.Integer as Int using (ℤ) | |
-- The rank-1 array [1, 2, 3].
mat₁ : Mat ℤ (3 ∷ [])
mat₁ = vector (Vec.map (λ k -> scalar (Int.+ k)) (1 ∷ 2 ∷ 3 ∷ []))
-- A 5×4×3 array filled with 0, 1, …, 59 in row-major order, so the
-- element at position (i, j, k) is 12·i + 3·j + k.
mat₂ : Mat ℤ (5 ∷ 4 ∷ 3 ∷ [])
mat₂ = fromVec (Vec.tabulate (λ pos -> Int.+ (Fin.toℕ pos)))
-- Outer positions 2 through 4, middle position 1, inner positions
-- 0 through 1 (all bounds inclusive), giving a 3×2 result.
slice₁ : Slice (5 ∷ 4 ∷ 3 ∷ []) (3 ∷ 2 ∷ [])
slice₁ = range [ 2 , 4 ] (index (suc zero) (range [ 0 , 1 ] all))
-- mat₂ restricted by slice₁; given mat₂'s layout this is the 3×2 array
-- [[27, 28], [39, 40], [51, 52]].
mat₃ : Mat ℤ (3 ∷ 2 ∷ [])
mat₃ = slice {A = ℤ} slice₁ mat₂
-- mat₃ flattened into a single dimension of length 6. The element-count
-- proof (product [3, 2] ≡ product [6], both 6) is supplied explicitly
-- as refl instead of being left to instance search.
mat₄ : Mat ℤ (6 ∷ [])
mat₄ = reshape {{refl}} mat₃
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment