Skip to content

Instantly share code, notes, and snippets.

@Vindaar
Last active December 29, 2020 09:58
Show Gist options
  • Save Vindaar/ac49c9caf063c04caaf74dcae08eca79 to your computer and use it in GitHub Desktop.
Save Vindaar/ac49c9caf063c04caaf74dcae08eca79 to your computer and use it in GitHub Desktop.
Pretty-printing of N-dimensional tensors, including N > 3
import sequtils
type
  # Minimal dense tensor: all elements live in the flat `data` seq and are
  # addressed via `shape` and `strides`.
  Tensor[T] = object
    size: int          ## total element count; set by `initTensor` as the
                       ## product of `shape`
    shape: seq[int]    ## extent of each axis
    strides: seq[int]  ## step in `data` per unit move along each axis; set
                       ## by `initTensor` via `parseStrides` (axis 0 has
                       ## stride 1)
    data: seq[T]       ## flat element storage of length `size`
func parseStrides(shape: seq[int]): seq[int] =
  ## Computes column-major (first-axis-fastest) strides for `shape`.
  ##
  ## The stride of axis `k` is the product of `shape[0 ..< k]`, so axis 0
  ## always has stride 1. For example `@[3, 3, 2]` gives `@[1, 3, 9]`.
  ## An empty shape yields an empty seq.
  result = newSeqOfCap[int](shape.len)
  var stride = 1
  for extent in shape:
    result.add stride
    stride *= extent
proc initTensor[T](shape: varargs[int]): Tensor[T] =
  ## Creates a tensor of the given `shape` with every element set to the
  ## default value of `T`, plus strides computed by `parseStrides`.
  ##
  ## An empty `shape` yields `size == 1` (the empty product).
  ## The constructor produces no output.
  result.shape = @shape
  result.size = result.shape.foldl(a * b, 1)
  result.data = newSeq[T](result.size)
  result.strides = parseStrides(result.shape)
proc toTensor[T](x: openArray[T], shape: varargs[int]): Tensor[T] =
  ## Builds a tensor of the given `shape` and copies the elements of the
  ## flat array `x`, in order, into its `data`.
  ##
  ## Positions beyond `x.len` keep the default value of `T`. An `x`
  ## longer than the tensor writes past `data`, which is an index error
  ## in bounds-checked builds.
  result = initTensor[T](shape)
  var pos = 0
  while pos < x.len:
    result.data[pos] = x[pos]
    inc pos
proc computeIdx[T](t: Tensor[T], axis: int, i, j: int): int =
  ## Returns the flat index into `t.data` of the `j`-th element of the
  ## `i`-th 1D fiber that runs along `axis`.
  ##
  ## Fibers are numbered by decomposing `i` into a multi-index over every
  ## axis except `axis`, fastest-varying (lowest) axis first. Each
  ## coordinate is then scaled by that axis' entry in `t.strides`.
  ##
  ## Valid ranges:
  ## - `i` in `0 ..< t.size div t.shape[axis]`
  ## - `j` in `0 ..< t.shape[axis]`
  ##
  ## For `axis == 0` with column-major strides this reduces to
  ## `i * t.shape[0] + j`.
  var rest = i
  for ax, extent in t.shape:
    if ax == axis:
      continue
    # The fiber's coordinate along this non-iterated axis.
    result += (rest mod extent) * t.strides[ax]
    rest = rest div extent
  result += j * t.strides[axis]
proc stridedPrinting[T](t: Tensor[T], axis: int) =
  ## Echoes every 1D fiber of `t` along `axis`, one fiber per line.
  ##
  ## Each element is followed by ", ". The whole text is emitted with a
  ## single `echo`, which appends one more newline.
  let axLen = t.shape[axis]
  # A zero-length axis means an empty tensor; avoid `div` by zero.
  let numToStride = if axLen == 0: 0 else: t.size div axLen
  var res = ""
  for i in 0 ..< numToStride:
    for j in 0 ..< axLen:
      res.add $t.data[computeIdx(t, axis, i, j)]
      res.add ", "
    res.add "\n"
  echo res
proc pretty[T](t: Tensor[T]): string =
  ## Renders `t` as text, one fiber along axis 0 per line, each element
  ## followed by ", ".
  ##
  ## For tensors with more than two axes, every 2D slice (axes 0 and 1) is
  ## preceded by a blank line and a header giving the slice's index along
  ## each higher axis, e.g. "(dim: 2, idx: 0)(dim: 3, idx: 1)".
  ## An empty tensor renders as "".
  if t.size == 0:
    return
  let axLen = t.shape[0]
  let numToStride = t.size div axLen
  # prefix[l] = product of shape[0 ..< l], i.e. how many flat elements one
  # step along axis l spans. Computed once instead of per element.
  var prefix = newSeq[int](t.shape.len + 1)
  prefix[0] = 1
  for l, extent in t.shape:
    prefix[l + 1] = prefix[l] * extent
  for i in 0 ..< numToStride:
    for j in 0 ..< axLen:
      let idx = computeIdx(t, 0, i, j)
      # A new 2D slice starts when idx is a multiple of shape[0]*shape[1].
      # Every higher prefix is a multiple of that, so this single test
      # covers all boundaries.
      if t.shape.len > 2 and idx mod prefix[2] == 0:
        result.add "\n"
        for l in 2 ..< t.shape.len:
          result.add("(dim: " & $l & ", idx: " &
                     $((idx div prefix[l]) mod t.shape[l]) & ")")
        result.add "\n"
      result.add($t.data[idx] & ", ")
    result.add "\n"
# Demo: print every axis of a 2D, 3D and 4D tensor fiber by fiber, then the
# "pretty" rendering of a 6D tensor. Tensors are created in the same order
# as they are printed.
let t = toSeq(1 .. 9).toTensor(3, 3)
for ax in 0 ..< t.shape.len:
  stridedPrinting(t, axis = ax)
let t2 = toSeq(1 .. 18).toTensor(3, 3, 2)
for ax in 0 ..< t2.shape.len:
  stridedPrinting(t2, axis = ax)
let t3 = toSeq(1 .. 36).toTensor(3, 3, 2, 2)
for ax in 0 ..< t3.shape.len:
  stridedPrinting(t3, axis = ax)
let t4 = toSeq(1 .. 144).toTensor(3, 3, 2, 2, 2, 2)
echo t4.pretty()
@Vindaar
Copy link
Author

Vindaar commented Dec 29, 2020

Output of stridedPrinting for t:

(size: 9, shape: @[3, 3], strides: @[1, 3], data: @[0, 0, 0, 0, 0, 0, 0, 0, 0])
1, 2, 3,                                                   
4, 5, 6,                                                                                                               
7, 8, 9,    
                                                                                                                       
1, 4, 7,       
2, 5, 8,                                                   
3, 6, 9,                                                                                                               

Output of stridedPrinting for t2:

(size: 18, shape: @[3, 3, 2], strides: @[1, 3, 9], data: @[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
1, 2, 3, 
4, 5, 6, 
7, 8, 9, 
10, 11, 12, 
13, 14, 15, 
16, 17, 18, 

1, 4, 7, 
2, 5, 8, 
3, 6, 9, 
4, 7, 10, 
5, 8, 11, 
6, 9, 12, 

1, 10, 
2, 11, 
3, 12, 
4, 13, 
5, 14, 
6, 15, 
7, 16, 
8, 17, 
9, 18, 

Output of stridedPrinting for t3:

(size: 36, shape: @[3, 3, 2, 2], strides: @[1, 3, 9, 18], data: @[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
1, 2, 3, 
4, 5, 6, 
7, 8, 9, 
10, 11, 12, 
13, 14, 15, 
16, 17, 18, 
19, 20, 21, 
22, 23, 24, 
25, 26, 27, 
28, 29, 30, 
31, 32, 33, 
34, 35, 36, 

1, 4, 7, 
2, 5, 8, 
3, 6, 9, 
4, 7, 10, 
5, 8, 11, 
6, 9, 12, 
7, 10, 13, 
8, 11, 14, 
9, 12, 15, 
10, 13, 16, 
11, 14, 17, 
12, 15, 18, 

1, 10, 
2, 11, 
3, 12, 
4, 13, 
5, 14, 
6, 15, 
7, 16, 
8, 17, 
9, 18, 
10, 19, 
11, 20, 
12, 21, 
13, 22, 
14, 23, 
15, 24, 
16, 25, 
17, 26, 
18, 27, 

1, 19, 
2, 20, 
3, 21, 
4, 22, 
5, 23, 
6, 24, 
7, 25, 
8, 26, 
9, 27, 
10, 28, 
11, 29, 
12, 30, 
13, 31, 
14, 32, 
15, 33, 
16, 34, 
17, 35, 
18, 36, 

"Pretty" printing of t4:

(size: 144, shape: @[3, 3, 2, 2, 2, 2], strides: @[1, 3, 9, 18, 36, 72], data: @[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
                                                           
(dim: 2, idx: 0)(dim: 3, idx: 0)(dim: 4, idx: 0)(dim: 5, idx: 0)                                                       
1, 2, 3,                                                   
4, 5, 6,                                                                                                               
7, 8, 9,       
                                                           
(dim: 2, idx: 1)(dim: 3, idx: 0)(dim: 4, idx: 0)(dim: 5, idx: 0)                                                       
10, 11, 12,                                                
13, 14, 15,                                                                                                            
16, 17, 18,    
                                                           
(dim: 2, idx: 0)(dim: 3, idx: 1)(dim: 4, idx: 0)(dim: 5, idx: 0)                                                       
19, 20, 21,                                                
22, 23, 24,                                                                                                            
25, 26, 27,    
                                                           
(dim: 2, idx: 1)(dim: 3, idx: 1)(dim: 4, idx: 0)(dim: 5, idx: 0)                                                       
28, 29, 30,                                                
31, 32, 33, 
34, 35, 36, 

(dim: 2, idx: 0)(dim: 3, idx: 0)(dim: 4, idx: 1)(dim: 5, idx: 0)
37, 38, 39, 
40, 41, 42, 
43, 44, 45, 

(dim: 2, idx: 1)(dim: 3, idx: 0)(dim: 4, idx: 1)(dim: 5, idx: 0)
46, 47, 48, 
49, 50, 51, 
52, 53, 54, 

(dim: 2, idx: 0)(dim: 3, idx: 1)(dim: 4, idx: 1)(dim: 5, idx: 0)
55, 56, 57, 
58, 59, 60, 
61, 62, 63, 

(dim: 2, idx: 1)(dim: 3, idx: 1)(dim: 4, idx: 1)(dim: 5, idx: 0)
64, 65, 66, 
67, 68, 69, 
70, 71, 72, 

(dim: 2, idx: 0)(dim: 3, idx: 0)(dim: 4, idx: 0)(dim: 5, idx: 1)
73, 74, 75, 
76, 77, 78, 
79, 80, 81, 

(dim: 2, idx: 1)(dim: 3, idx: 0)(dim: 4, idx: 0)(dim: 5, idx: 1)
82, 83, 84, 
85, 86, 87, 
88, 89, 90, 

(dim: 2, idx: 0)(dim: 3, idx: 1)(dim: 4, idx: 0)(dim: 5, idx: 1)
91, 92, 93, 
94, 95, 96, 
97, 98, 99, 

(dim: 2, idx: 1)(dim: 3, idx: 1)(dim: 4, idx: 0)(dim: 5, idx: 1)
100, 101, 102, 
103, 104, 105, 
106, 107, 108, 

(dim: 2, idx: 0)(dim: 3, idx: 0)(dim: 4, idx: 1)(dim: 5, idx: 1)
109, 110, 111, 
112, 113, 114, 
115, 116, 117, 

(dim: 2, idx: 1)(dim: 3, idx: 0)(dim: 4, idx: 1)(dim: 5, idx: 1)
118, 119, 120, 
121, 122, 123, 
124, 125, 126, 

(dim: 2, idx: 0)(dim: 3, idx: 1)(dim: 4, idx: 1)(dim: 5, idx: 1)
127, 128, 129, 
130, 131, 132, 
133, 134, 135, 

(dim: 2, idx: 1)(dim: 3, idx: 1)(dim: 4, idx: 1)(dim: 5, idx: 1)
136, 137, 138, 
139, 140, 141, 
142, 143, 144, 

(dim: 2, idx: 1)(dim: 3, idx: 1)(dim: 4, idx: 0)(dim: 5, idx: 1)
100, 101, 102, 
103, 104, 105, 
106, 107, 108, 

(dim: 2, idx: 0)(dim: 3, idx: 0)(dim: 4, idx: 1)(dim: 5, idx: 1)
109, 110, 111, 
112, 113, 114, 
115, 116, 117, 

(dim: 2, idx: 1)(dim: 3, idx: 0)(dim: 4, idx: 1)(dim: 5, idx: 1)
118, 119, 120, 
121, 122, 123, 
124, 125, 126, 

(dim: 2, idx: 0)(dim: 3, idx: 1)(dim: 4, idx: 1)(dim: 5, idx: 1)
127, 128, 129, 
130, 131, 132, 
133, 134, 135, 

(dim: 2, idx: 1)(dim: 3, idx: 1)(dim: 4, idx: 1)(dim: 5, idx: 1)
136, 137, 138, 
139, 140, 141, 
142, 143, 144, 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment